Generic Concurrency in Go
Hello Gophers!
In this article, I want to share thoughts and ideas that I’ve accumulated over time regarding generics in Go and, in particular, concurrency patterns, which can now become more reusable and convenient with the use of generics.
TL;DR
Generics and goroutines (and iterators in the future) are great tools we can leverage to build reusable, general-purpose concurrent processing in our programs.
In this article we explore the possibilities of combining them.
Introduction
Let’s quickly cover some basic context and small examples to see what problem generics solve and how we can fuse them with the existing concurrency model.
In this article we are going to think a lot about mapping collections (sets, sequences) of elements. Mapping is a process that results in a new collection of elements, where each element is the result of a call to some function f() with the corresponding element from the initial collection.
Pre-Generics era
Let’s define the first simple mapping for integer numbers (which in the Go snippets we will call transform(), so as not to confuse it with the builtin map type):
func transform([]int, func(int) int) []int
Sample implementation
func transform(xs []int, f func(int) int) []int {
ret := make([]int, len(xs))
for i, x := range xs {
ret[i] = f(x)
}
return ret
}
An example use of such a function would look like this:
// Output: [1, 4, 9]
transform([]int{1, 2, 3}, func(n int) int {
return n * n
})
Now let’s assume we want to map integers to strings. That’s easy – we can define transform() just slightly differently:
func transform([]int, func(int) string) []string
So we can use it this way:
// Output: ["1", "2", "3"]
transform([]int{1, 2, 3}, strconv.Itoa)
What about reporting whether a number is odd or even? Just another tiny correction:
func transform([]int, func(int) bool) []bool
So we could use it this way:
// Output: [false, true, false]
transform([]int{1, 2, 3}, func(n int) bool {
return n % 2 == 0
})
Generalising the corrections to transform() we’ve made above for each use case, we can say that regardless of the types it operates on, it does exactly the same thing over and over again. If we were to generate the code for each type involved using text/template templates, we could do it like this:
func transform_{{ .A }}_{{ .B }}([]{{ .A }}, func({{ .A }}) {{ .B }}) []{{ .B }}
// transform_int_int([]int, func(int) int) []int
// transform_int_string([]int, func(int) string) []string
// transform_int_bool([]int, func(int) bool) []bool
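For illustration, here is a minimal sketch of how such generation could be driven; the set of type pairs fed to the template is an assumption made for this example:
package main

import (
	"os"
	"text/template"
)

// tmpl renders one monomorphised transform function per pair of types.
var tmpl = template.Must(template.New("transform").Parse(
	`func transform_{{ .A }}_{{ .B }}(xs []{{ .A }}, f func({{ .A }}) {{ .B }}) []{{ .B }} {
	ret := make([]{{ .B }}, len(xs))
	for i, x := range xs {
		ret[i] = f(x)
	}
	return ret
}
`))

func main() {
	// Generate a copy of transform() for every type pair we need.
	for _, p := range []struct{ A, B string }{
		{"int", "int"},
		{"int", "string"},
		{"int", "bool"},
	} {
		if err := tmpl.Execute(os.Stdout, p); err != nil {
			panic(err)
		}
	}
}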
Actually, there were a few nice code generation tools that did almost exactly this templating for pre-generics versions of Go; genny is just one example.
Generics era
Thanks to generics, we now have the ability to parametrise functions and types with type parameters and to define transform() this way:
func transform[A, B any]([]A, func(A) B) []B
And the implementation changes just a little bit!
func transform[A, B any](xs []A, f func(A) B) []B {
ret := make([]B, len(xs))
for i, x := range xs {
ret[i] = f(x)
}
return ret
}
So we can now use it for any input and output types (assuming we have square(int) int and isEven(int) bool defined somewhere in the package):
transform([]int{1, 2, 3}, square) // [1, 4, 9]
transform([]int{1, 2, 3}, strconv.Itoa) // ["1", "2", "3"]
transform([]int{1, 2, 3}, isEven) // [false, true, false]
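For completeness, the assumed helpers could be as trivial as:
func square(n int) int { return n * n }

func isEven(n int) bool { return n%2 == 0 }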
Concurrent mapping
Okay, now let’s get on to the main subject of this article and focus on concurrency patterns that can benefit from generics.
The x/sync/errgroup package
Before jumping into lots of code snippets, let’s take a tiny step aside and look at the (very popular) golang.org/x/sync/errgroup library. In short, it allows you to start a number of goroutines to perform different tasks and to wait for their completion or failure.
It is supposed to be used this way:
// Create workers group and a context which will get canceled if any of the
// tasks fails.
g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
return doSomeFun(gctx)
})
g.Go(func() error {
return doEvenMoreFun(gctx)
})
if err := g.Wait(); err != nil {
// handle error
}
The reason I mention the package is that, viewed from a slightly different and a bit more generalised perspective, it is essentially the same mapping thing. The package allows you to concurrently map a set of tasks to a corresponding set of results, and provides a generalised way of handling and propagating errors, as well as cancelling subtasks (via context cancellation) if any of them fails.
In this article we want to build something similar and, as the repeated use of the word “generic” suggests, we will be doing it in a generic way.
Naive implementation
Getting back to the transform() function, let’s assume that all the calls to f() can be done concurrently without breaking our (or anyone else’s) program. Then we can start with this naive concurrent implementation:
func transform[A, B any](as []A, f func(A) B) []B {
bs := make([]B, len(as))
var wg sync.WaitGroup
for i := 0; i < len(as); i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
bs[i] = f(as[i])
}(i)
}
wg.Wait()
return bs
}
That is, we start a goroutine for each element of the input and call f(elem), then store the result at the corresponding index of the shared slice bs. No context, no cancellations, not even errors – this one doesn’t look like something very helpful for anything besides pure computation.
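Still, even this version gives us concurrency for free. As a quick sketch (the sleep here just simulates an expensive computation), all elements are processed simultaneously, while each result still lands at its input index:
start := time.Now()
squares := transform([]int{1, 2, 3}, func(n int) int {
	time.Sleep(100 * time.Millisecond) // Simulate heavy computation.
	return n * n
})
// Prints [1 4 9] after roughly 100ms in total, instead of 300ms.
fmt.Println(squares, time.Since(start))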
Context cancellation
In the real world, many or even most concurrent tasks, especially the i/o related ones, would be controlled by a context.Context instance. Since there is a context, there can be a timeout or a cancellation. Let’s take this into account (here and below I’ll highlight the lines that were added compared to the previous code sample):
func transform[A, B any](
ctx context.Context,
as []A,
f func(context.Context, A) (B, error),
) (
[]B,
error,
) {
bs := make([]B, len(as))
es := make([]error, len(as))
subctx, cancel := context.WithCancel(ctx)
defer cancel()
var wg sync.WaitGroup
for i := 0; i < len(as); i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
bs[i], es[i] = f(subctx, as[i])
if es[i] != nil {
cancel()
}
}(i)
}
wg.Wait()
err := errors.Join(es...)
if err != nil {
return nil, err
}
return bs, nil
}
Now we have one more shared slice, es, to store the errors potentially returned by f(). If any goroutine’s f() fails, we cancel the entire transform() context and expect every inflight f() call to respect the cancellation and return as soon as possible.
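For example, a sketch of an f() that cooperates with cancellation might look like this, with a one second timer standing in for real work:
transform(ctx, []string{"a", "b", "c"}, func(ctx context.Context, key string) (string, error) {
	select {
	case <-time.After(time.Second):
		// The "work" is done.
		return "processed " + key, nil
	case <-ctx.Done():
		// Some other f() call failed and cancelled the shared sub-context;
		// return promptly instead of finishing the work.
		return "", ctx.Err()
	}
})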
Limiting concurrency
In reality, we cannot implicitly assume too much about f(). Users of transform() might want to limit the number of concurrent calls to f(). For example, f() could map a url to the result of an http request. Without any limits, we could overwhelm the server or get banned ourselves.
Let’s not think about the parameters structure for now, and just add a parallelism int argument to the function.
At this point we need to switch from a sync.WaitGroup to a semaphore chan, as we want to control the (maximum) number of simultaneously running goroutines as well as handle context cancellation, both by using select.
func transform[A, B any](
ctx context.Context,
parallelism int,
as []A,
f func(context.Context, A) (B, error),
) (
[]B,
error,
) {
bs := make([]B, len(as))
es := make([]error, len(as))
// FIXME: if the given context is already cancelled, no worker will be
// started but the transform() call will return bs, nil.
subctx, cancel := context.WithCancel(ctx)
defer cancel()
sem := make(chan struct{}, parallelism)
sched:
for i := 0; i < len(as); i++ {
// We are checking the sub-context cancellation here, in addition to
// the user-provided context, to handle cases where f() returns an
// error, which leads to the termination of transform.
if subctx.Err() != nil {
break
}
select {
case <-subctx.Done():
break sched
case sem <- struct{}{}:
// Being able to send a tick into the channel means we can start a
// new worker goroutine. This could be either due to the completion
// of a previous goroutine or because the number of started worker
			// goroutines is less than the given parallelism value.
}
go func(i int) {
defer func() {
// Signal that the element has been processed and the worker
// goroutine has completed.
<-sem
}()
bs[i], es[i] = f(subctx, as[i])
if es[i] != nil {
cancel()
}
}(i)
}
	// Since each goroutine reads off one tick from the semaphore before
	// exiting, filling the channel with artificial ticks ensures that all
	// started goroutines have completed their execution.
//
// FIXME: for the high values of parallelism this loop becomes slow.
for i := 0; i < cap(sem); i++ {
// NOTE: we do not check the user-provided context here because we want
// to return from this function only when all the started worker
// goroutines have completed. This is to avoid surprising users with
// some of the f() function calls still running in the background after
// transform() returns.
//
// This implies f() should respect context cancellation and return as
// soon as its context gets cancelled.
sem <- struct{}{}
}
err := errors.Join(es...)
if err != nil {
return nil, err
}
return bs, nil
}
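As a usage sketch, this is how one could fetch a list of urls with at most 8 concurrent requests (ctx and urls are assumed to be defined elsewhere):
statuses, err := transform(ctx, 8, urls,
	func(ctx context.Context, url string) (int, error) {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return 0, err
		}
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return 0, err
		}
		resp.Body.Close()
		return resp.StatusCode, nil
	},
)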
For this and the next iterations of transform() we actually could leave the implementation as it is and leave both use cases at the mercy of the f() implementation. For example, we could just start N goroutines regardless of the concurrency limits and let the user of transform() partially serialise them the way they want to. That would imply the overhead of starting N goroutines instead of P (where P is the “parallelism” limit, which can be much less than N). It would also imply some compute overhead for synchronising the goroutines, depending on the mechanism used. Since all of this is unnecessary, we proceed with the implementation the hard way; but for many cases these complications are optional.
Example user-based implementation
// Initialised x/time/rate.Limiter instance.
var lim *rate.Limiter

transform(ctx, as, func(_ context.Context, url string) (int, error) {
	if err := lim.Wait(ctx); err != nil {
		return 0, err
	}
	// Process url.
	return 42, nil
})
Reusing goroutines
In the previous iteration we were starting a goroutine per each task, but no more parallelism
goroutines at a time. This highlights another interesting option – users might want to have a custom execution context per each goroutine. For example, suppose we have N
tasks with maximum P
running concurrently (and P
can be significantly less than N
). If each task requires some form of resource preparation, such as a large memory allocation, a database session, or maybe a single-threaded Cgo “coroutine”, it would seem logical to prepare only P
resources and reuse them among workers through context.
Again, let’s leave the structure of passing options aside.
func transform[A, B any](
ctx context.Context,
prepare func(context.Context) (context.Context, context.CancelFunc),
parallelism int,
as []A,
f func(context.Context, A) (B, error),
) (
[]B,
error,
) {
bs := make([]B, len(as))
es := make([]error, len(as))
// FIXME: if the given context is already cancelled, no worker will be
// started but the transform() call will return bs, nil.
subctx, cancel := context.WithCancel(ctx)
defer cancel()
sem := make(chan struct{}, parallelism)
wrk := make(chan int)
sched:
for i := 0; i < len(as); i++ {
// We are checking the sub-context cancellation here, in addition to
// the user-provided context, to handle cases where f() returns an
// error, which leads to the termination of transform.
if subctx.Err() != nil {
break
}
select {
case <-subctx.Done():
break sched
case wrk <- i:
// There is an idle worker goroutine that is ready to process the
// next element.
continue
case sem <- struct{}{}:
// Being able to send a tick into the channel means we can start a
// new worker goroutine. This could be either due to the completion
// of a previous goroutine or because the number of started worker
			// goroutines is less than the given parallelism value.
}
go func(i int) {
defer func() {
// Signal that the element has been processed and the worker
// goroutine has completed.
<-sem
}()
// Capture the subctx from the dispatch loop. This prevents
// overriding it if the given prepare() function is not nil.
subctx := subctx
if prepare != nil {
var cancel context.CancelFunc
subctx, cancel = prepare(subctx)
defer cancel()
}
for {
bs[i], es[i] = f(subctx, as[i])
if es[i] != nil {
cancel()
return
}
var ok bool
i, ok = <-wrk
if !ok {
// Work channel has been closed, which means we will not
// get any new tasks for this worker and can return.
break
}
}
}(i)
	}
	// Signal the started worker goroutines that there are no more elements
	// to dispatch, so they can return.
	close(wrk)
	// Since each goroutine reads off one tick from the semaphore before
	// exiting, filling the channel with artificial ticks ensures that all
	// started goroutines have completed their execution.
//
// FIXME: for the high values of parallelism this loop becomes slow.
for i := 0; i < cap(sem); i++ {
// NOTE: we do not check the user-provided context here because we want
// to return from this function only when all the started worker
// goroutines have completed. This is to avoid surprising users with
// some of the f() function calls still running in the background after
// transform() returns.
//
// This implies f() should respect context cancellation and return as
// soon as its context gets cancelled.
sem <- struct{}{}
}
err := errors.Join(es...)
if err != nil {
return nil, err
}
return bs, nil
}
At this point we start up to P goroutines and distribute the tasks across them using the unbuffered channel wrk. The channel is unbuffered because we want immediate runtime “feedback” telling us whether there are any idle workers at the moment or whether we should consider starting a new one. Once all the tasks are processed or any of the f() calls fails, we signal all the started goroutines to return (by doing close(wrk)).
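To make the prepare() contract concrete, here is a minimal sketch that attaches a per-worker resource to the context; newSession(), its Close() method and the context key are illustrative assumptions:
type sessionKey struct{}

prepare := func(ctx context.Context) (context.Context, context.CancelFunc) {
	// Hypothetical: acquire an expensive per-worker resource once.
	sess := newSession()
	return context.WithValue(ctx, sessionKey{}, sess), func() {
		// Released exactly once, when the worker goroutine exits.
		sess.Close()
	}
}

Each f() call can then retrieve the resource via ctx.Value(sessionKey{}) without paying the preparation cost per task.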
As in the previous section, this might be done inside f() too, for example, by using sync.Pool. f() could acquire a resource (or create one when there are no idle resources) and release it once it’s not needed anymore. Since the set of goroutines is fixed, odds are the resources can have nice CPU locality, so the overhead could be minimal.
Example user-based implementation
// Note that this snippet assumes `transform()` can limit its concurrency.
var pool sync.Pool

transform(ctx, 8, as, func(_ context.Context, userID string) (int, error) {
	sess, _ := pool.Get().(*db.Session)
	if sess == nil {
		// Initialise database session.
	}
	defer pool.Put(sess)
	// Process userID.
	return 42, nil
})
Generalisation of transform()
So far our focus has been on mapping slices, which in many cases is enough. However, what if we want to map map types, or maybe even chan? Can we map anything that we can range over? And, as in for loops, do we always really need to map values at all?
These are all interesting questions, and they lead us to the idea of generalising our concurrent iteration approach. We can have a “low level” function that behaves almost the same but makes fewer assumptions about its input and output. Then it will take just a little effort to build a slightly more specific transform() on top of it. Let’s call the function iterate() and represent its input and output as functions instead of data types. We will pull() elements from the input and push() the results back to the user. This way the user of iterate() controls the way the input elements are provided and the way the results are handled.
We also have to consider what results iterate() should push to the user. Since we plan to make the mapping of input elements optional, (B, error) no longer seems to be the only right and obvious option. This part is actually quite subtle, and maybe the majority of the function’s use cases would benefit from keeping it as it was, returning the error explicitly. Semantically, however, it doesn’t make much sense: the f() result is only proxied down to the push() call without any processing, which means that iterate() really has no presumptions about the result. In other words, the result makes sense only to the push() function implementation, which is provided by the user. Additionally, this signature will work better with Go iterators, which we’ll cover at the end of this article. So with this in mind, let’s try to reduce the number of return parameters down to one. Since we intend to push results through a function call, we should do it in a serialised way – transform(), and later iterate(), already have all the needed synchronisation internally, so the user can collect the results without extra synchronisation efforts on their side.
Another thing to cover is the way we were handling errors earlier – we did not map an error to the input element that caused it. It is true that f() could wrap an error, but it is cleaner for f() to remain unaware of the way it will be called. In other words, f() should not assume it is being called as the iterate() argument. If it was invoked with a single element a, there is no point in wrapping a into the error, as it’s obvious to the caller that this particular a caused the error. This principle leads us to another observation – any potential binding of an input element to an error (or any other result) should also occur during the push() execution. For the same reasons, it is the push() function that should control the iteration and decide whether a faulty result should interrupt the loop.
Additionally, this iterate() design naturally provides nice flow control. If the user does something slow in the push() function, the other worker goroutines will eventually pause processing new elements, because they will get blocked sending their f() call results into the res channel, which in turn is drained by the function calling push().
As the function code listing becomes too big, let’s cover it part by part.
Signature
func iterate[A, B any](
ctx context.Context,
prepare func(context.Context) (context.Context, context.CancelFunc),
parallelism int,
pull func() (A, bool),
f func(context.Context, A) B,
push func(A, B) bool,
) (err error) {
Neither the arguments nor the return parameters include input and output slices any more, as they did for the transform() function previously. Instead, input elements are pulled in by calling the pull() function, and results are pushed back to the user by calling the push() function. Note that the push() function returns a bool that controls the iteration – once false is returned, no more push() calls will be made and all ongoing f() executions will get their context canceled. iterate() returns just an error, which can only be non-nil when the iteration is terminated due to the given ctx cancellation – otherwise there is no way of knowing why the iteration has stopped.
Even though there are just three cases in which the loop can terminate:
- pull() returned false, meaning there are no more elements to process.
- push() returned false, meaning the user doesn’t need any further results.
- the parent ctx got cancelled.
it is still hard to say, without complicating the user code, whether all of the elements were processed before the parent context got cancelled.
Example
Let’s assume we want to implement a concurrent forEach() using iterate():
func forEach[A any](
	ctx context.Context,
	in []A,
	f func(context.Context, A) error,
) (err error) {
	var i int
	iterate(ctx, nil, 0,
		func() (_ A, ok bool) {
			if i == len(in) {
				return
			}
			i++
			return in[i-1], true
		},
		f,
		func(_ A, e error) bool {
			err = e
			return e == nil
		},
	)
	if err == nil {
		// BUG: if we returned from `iterate()`, we either processed all the
		// input _or_ ctx got cancelled and the iteration was interrupted.
		//
		// Simply checking whether ctx.Err() is non-nil here is racy, and may
		// give a false faulty result in the case when we processed all the
		// input and _then_ the context got cancelled.
		//
		// On the other hand, checking whether i == len(in) as a condition of
		// completeness is incorrect, as we might pull the last element to
		// process and _then_ get interrupted by the context cancellation.
		//
		// So if iterate() doesn't return an error, one should track each
		// element's processing state in an `f()` wrapper to correctly
		// distinguish the cases above.
		err = ctx.Err()
	}
	return
}
Prologue
// Create sub-context for the dispatch loop goroutine so we can stop it
// once the user wants to stop the iteration.
subctx, cancel := context.WithCancel(ctx)
defer cancel()
	// result represents an input element A and the result B of applying
// the given function f() to A.
type result struct {
a A
b B
}
// loopInfo contains the dispatch loop state.
//
	// The dispatch goroutine below signals the current goroutine about the loop
// termination by sending loopInfo to the term channel below. The current
// goroutine uses it to understand how many elements have been dispatched
	// for processing, to decide how many results to await.
type loopInfo struct {
dispatched int
err error
}
// These channels are receive-only for the current goroutine and send-only
	// in the dispatch goroutine. For the sake of readability there are no type
// constraints added.
var (
res = make(chan result)
term = make(chan loopInfo, 1)
)
// This wait group is used to track completion of worker goroutines started
// by the dispatch goroutine.
var wg sync.WaitGroup
In previous versions of transform() we stored results at the index of the input element in the results slice, much like bs[i] = f(as[i]). This is no longer possible with function-based input and output. So, as soon as we have a result, we likely need to push() it to the user immediately. This is why we want two goroutines – one dispatching the input elements and one pushing the results back to the user: while we are dispatching an input, we might already be getting an output.
Dispatch goroutine
// Start the dispatch goroutine. Its purpose is to control the worker
// goroutines, dispatch input elements among the workers, and eventually
// signal the current goroutine about the dispatch loop termination.
go func() {
// wrk is a channel of input elements. It is send-only for the dispatch
		// goroutine and receive-only for the worker goroutines.
wrk := make(chan A)
var loop loopInfo
defer func() {
// Signal the workers there are no more elements to dispatch.
close(wrk)
// Report the dispatch loop state to the parent goroutine.
term <- loop
}()
var workersCount int
		// We use a _closed_ channel here so that the select below is always
		// able to receive from it and start up to the given parallelism number
// of goroutines. Once workersCount == parallelism, we set the variable
// to nil so that the select cannot read from it after.
//
// This is needed to:
// - Support the special case when parallelism is 0, so that there are
// no limits on the number of workers.
		// - Avoid wasting time "corking" the semaphore channel while waiting
// for all started goroutines to complete, especially if given a
// large parallelism value.
sem := make(chan struct{})
close(sem)
Dispatch loop
for {
if err := subctx.Err(); err != nil {
loop.err = err
return
}
a, ok := pull()
if !ok {
// No more input elements.
return
}
if parallelism != 0 && workersCount == parallelism {
// Prevent starting more workers.
sem = nil
}
select {
case <-subctx.Done():
loop.err = ctx.Err()
return
case wrk <- a:
// There is an idle worker goroutine that is ready to process
// the next element.
loop.dispatched++
continue
case <-sem:
// Being able to _receive_ a tick from the channel means we can
// start a new worker goroutine.
loop.dispatched++
}
workersCount++
wg.Add(1)
Worker goroutine
go func(a A) {
defer wg.Done()
// Capture the subctx from the topmost scope. This prevents
// overriding it if the given prepare() function is not nil.
subctx := subctx
if prepare != nil {
var cancel context.CancelFunc
subctx, cancel = prepare(subctx)
defer cancel()
}
for {
r := result{a: a}
r.b = f(subctx, a)
select {
case res <- r:
case <-subctx.Done():
// If the context is cancelled, it means no more
// results are expected.
return
}
var ok bool
a, ok = <-wrk
if !ok {
break
}
}
}(a)
Results collection
}
}()
collect:
// Wait for the results sent by the worker goroutines.
//
// Note the initial -1 value for the num variable since the number of
	// elements pulled and dispatched is unknown yet. We will be notified by
	// the dispatch goroutine once the input ends or the iteration is
// terminated.
for i, num := 0, -1; num == -1 || i < num; {
select {
case <-ctx.Done():
// We need to explicitly handle _parent_ context cancellation here
// because it's an external interruption for us. We ignore the
			// dispatch loop termination event and unconditionally stop
			// receiving and pushing results.
if err == nil {
err = ctx.Err()
}
break collect
case res := <-res:
if !push(res.a, res.b) {
// The user wants to stop the iteration. Signal the dispatch
// loop about this. Note that in this case, we ignore the term
				// channel message and do not return any error.
cancel()
break collect
}
i++
case loop := <-term:
// Dispatch loop has now terminated, and we now know the maximum
			// number of results we need to receive in this loop.
num = loop.dispatched
err = loop.err
}
}
// NOTE: we unconditionally wait for all goroutines to complete in order to
	// return to a clean state. To avoid an uninterruptible sleep here, users are
// required to respect context cancellation in the provided f().
wg.Wait()
return err
}
As you can see, results are pushed back to the user in random order – not the order they were pulled in. This is expected, since we process them concurrently.
Here’s an opinionated thought: this combination of the sync.WaitGroup and the sem channel is a rare example of a justified co-existence of both synchronisation mechanisms in the same code. I believe that in most cases where a channel exists, the wait group is redundant, and vice versa.
And phew, that’s it! It was not easy, but it is what we wanted. Let’s see how we can use it in the next sections.
Complete code listing
func iterate[A, B any](
ctx context.Context,
prepare func(context.Context) (context.Context, context.CancelFunc),
parallelism int,
pull func() (A, bool),
f func(context.Context, A) B,
push func(A, B) bool,
) (err error) {
// Create sub-context for the dispatch loop goroutine so we can stop it
// once the user wants to stop the iteration.
subctx, cancel := context.WithCancel(ctx)
defer cancel()
	// result represents an input element A and the result B of applying
// the given function f() to A.
type result struct {
a A
b B
}
// loopInfo contains the dispatch loop state.
//
	// The dispatch goroutine below signals the current goroutine about the loop
// termination by sending loopInfo to the term channel below. The current
// goroutine uses it to understand how many elements have been dispatched
	// for processing, to decide how many results to await.
type loopInfo struct {
dispatched int
err error
}
// These channels are receive-only for the current goroutine and send-only
	// in the dispatch goroutine. For the sake of readability there are no type
// constraints added.
var (
res = make(chan result)
term = make(chan loopInfo, 1)
)
// This wait group is used to track completion of worker goroutines started
// by the dispatch goroutine.
var wg sync.WaitGroup
// Start the dispatch goroutine. Its purpose is to control the worker
// goroutines, dispatch input elements among the workers, and eventually
// signal the current goroutine about the dispatch loop termination.
go func() {
// wrk is a channel of input elements. It is send-only for the dispatch
		// goroutine and receive-only for the worker goroutines.
wrk := make(chan A)
var loop loopInfo
defer func() {
// Signal the workers there are no more elements to dispatch.
close(wrk)
// Report the dispatch loop state to the parent goroutine.
term <- loop
}()
var workersCount int
		// We use a _closed_ channel here so that the select below is always
		// able to receive from it and start up to the given parallelism number
// of goroutines. Once workersCount == parallelism, we set the variable
// to nil so that the select cannot read from it after.
//
// This is needed to:
// - Support the special case when parallelism is 0, so that there are
// no limits on the number of workers.
		// - Avoid wasting time "corking" the semaphore channel while waiting
// for all started goroutines to complete, especially if given a
// large parallelism value.
sem := make(chan struct{})
close(sem)
for {
if err := subctx.Err(); err != nil {
loop.err = err
return
}
a, ok := pull()
if !ok {
// No more input elements.
return
}
if parallelism != 0 && workersCount == parallelism {
// Prevent starting more workers.
sem = nil
}
select {
case <-subctx.Done():
loop.err = ctx.Err()
return
case wrk <- a:
// There is an idle worker goroutine that is ready to process
// the next element.
loop.dispatched++
continue
case <-sem:
// Being able to _receive_ a tick from the channel means we can
// start a new worker goroutine.
loop.dispatched++
}
workersCount++
wg.Add(1)
go func(a A) {
defer wg.Done()
// Capture the subctx from the topmost scope. This prevents
// overriding it if the given prepare() function is not nil.
subctx := subctx
if prepare != nil {
var cancel context.CancelFunc
subctx, cancel = prepare(subctx)
defer cancel()
}
for {
r := result{a: a}
r.b = f(subctx, a)
select {
case res <- r:
case <-subctx.Done():
// If the context is cancelled, it means no more
// results are expected.
return
}
var ok bool
a, ok = <-wrk
if !ok {
break
}
}
}(a)
}
}()
collect:
// Wait for the results sent by the worker goroutines.
//
// Note the initial -1 value for the num variable since the number of
	// elements pulled and dispatched is unknown yet. We will be notified by
	// the dispatch goroutine once the input ends or the iteration is
// terminated.
for i, num := 0, -1; num == -1 || i < num; {
select {
case <-ctx.Done():
// We need to explicitly handle _parent_ context cancellation here
// because it's an external interruption for us. We ignore the
			// dispatch loop termination event and unconditionally stop
			// receiving and pushing results.
if err == nil {
err = ctx.Err()
}
break collect
case res := <-res:
if !push(res.a, res.b) {
// The user wants to stop the iteration. Signal the dispatch
// loop about this. Note that in this case, we ignore the term
				// channel message and do not return any error.
cancel()
break collect
}
i++
case loop := <-term:
// Dispatch loop has now terminated, and we now know the maximum
			// number of results we need to receive in this loop.
num = loop.dispatched
err = loop.err
}
}
// NOTE: we unconditionally wait for all goroutines to complete in order to
	// return to a clean state. To avoid an uninterruptible sleep here, users are
// required to respect context cancellation in the provided f().
wg.Wait()
return err
}
Using iterate() to transform()
To test how the generic iteration function can solve the mapping problem, let’s re-implement transform() using it. It now looks much shorter, as we have moved the concurrent iteration complexity away from it and can focus basically just on storing the mapping results.
func transform[A, B any](
ctx context.Context,
prepare func(context.Context) (context.Context, context.CancelFunc),
parallelism int,
as []A,
f func(context.Context, A) (B, error),
) (
[]B, error,
) {
bs := make([]B, len(as))
var (
i int
err1 error
)
err0 := iterate(ctx, prepare, parallelism,
func() (int, bool) {
i++
return i - 1, i <= len(as)
},
func(ctx context.Context, i int) (err error) {
bs[i], err = f(ctx, as[i])
return
},
func(i int, err error) bool {
err1 = err
return err == nil
},
)
if err := errors.Join(err0, err1); err != nil {
return nil, err
}
return bs, nil
}
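The behaviour stays the same as in the hand-written version; a quick usage sketch:
// Output: ["1", "2", "3"], <nil>
bs, err := transform(ctx, nil, 2, []int{1, 2, 3},
	func(_ context.Context, n int) (string, error) {
		return strconv.Itoa(n), nil
	},
)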
Reimplemented errgroup
To conclude the analogy with the errgroup package, let’s try to implement something similar using the iterate() approach.
type taskFunc func(context.Context) error
func errgroup(ctx context.Context) (
g func(taskFunc),
wait func() error,
) {
task := make(chan taskFunc)
done := make(chan struct{})
var (
err error
failure error
)
go func() {
defer close(done)
// NOTE: we ignore the context preparation here as we don't need it. We
		// also don't limit the number of goroutines running at the same time -- we
		// want each task to start executing as soon as possible.
err = iterate(ctx, nil, 0,
func() (f taskFunc, ok bool) {
f, ok = <-task
return
},
func(ctx context.Context, f taskFunc) error {
return f(ctx)
},
func(_ taskFunc, err error) bool {
if err != nil {
// Cancel the group work and stop taking new tasks.
failure = err
return false
}
return true
},
)
}()
g = func(fn taskFunc) {
// If wait() wasn't called yet, but a previously scheduled task has
// failed already, we should ignore the task and avoid deadlock here.
select {
case task <- fn:
case <-done:
}
}
wait = func() error {
close(task)
<-done
return errors.Join(err, failure)
}
return
}
So the use of the function is very similar to the errgroup package:
// Create the workers group. The context passed to each task will be canceled
// if any of the tasks fails.
g, wait := errgroup(ctx)
g(func(ctx context.Context) error {
	return doSomeFun(ctx)
})
g(func(ctx context.Context) error {
	return doEvenMoreFun(ctx)
})
if err := wait(); err != nil {
// handle error
}
Go Iterators
Let’s briefly look at the near future of Go in relation to the ideas implemented above.
With the recent (as of Go 1.22) range over functions experiment, it is possible to do a usual range over functions compatible with the sequence iterator types defined by the iter package. This is a quite new concept in Go which will hopefully ship in a future version as part of the standard library. For more information, please read the range over func proposal as well as the preceding article on coroutines in Go by Russ Cox, which the experimental iter package is built on top of.
Adjusting iterate() to be iter compatible is easy as pie:
// Note that Go has no function overloading, so in a real package this
// iter-compatible wrapper would need a name distinct from the iterate()
// defined above; we keep the name here for readability.
func iterate[A, B any](
ctx context.Context,
prepare func(context.Context) (context.Context, context.CancelFunc),
parallelism int,
seq iter.Seq[A],
f func(context.Context, A) B,
) iter.Seq2[A, B] {
return func(yield func(A, B) bool) {
pull, stop := iter.Pull(seq)
defer stop()
iterate(ctx, prepare, parallelism, pull, f, yield)
}
}
Enabling that experiment allows us to do an amazing thing – to iterate over the results of concurrently processed elements of a sequence in a regular for loop!
// Assuming the standard library supports iterators.
seq := slices.Sequence([]int{1, 2, 3})
// Prints the pairs (1, 1), (2, 4), (3, 9) in arbitrary order.
for a, b := range iterate(ctx, nil, 0, seq, func(_ context.Context, n int) int {
	return square(n)
}) {
	fmt.Println(a, b)
}
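As of Go 1.23, the iter package and iterator helpers are part of the standard library, so the hypothetical slices.Sequence above can be replaced with the real slices.Values:
for a, b := range iterate(ctx, nil, 0, slices.Values([]int{1, 2, 3}), func(_ context.Context, n int) int {
	return square(n)
}) {
	fmt.Println(a, b)
}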
Conclusion
I wish this were a part of the Go standard library.
Initially I wanted this first sentence to be the only content of this conclusion section, but probably at least a few words should be said about why. I believe such general-purpose utilities can be much better conveyed and accepted by projects if the majority of the community agrees on how the utilities are designed and built. Of course we can have some libraries solving similar problems, but in my opinion, the more different libraries we have, the more disagreement we may get in the community about what, when and how to use them. In some cases there is nothing wrong with having widely different approaches and implementations, but in others it can mean not having a complete solution at all. Very often libraries are initially born as a much more specific solution than is needed for wide adoption; for a really general-purpose solution, the design, the API and then the implementation should be discussed well before the actual work takes place. This is how OSS foundations solve similar problems, or the Go team in the case of Go. Having something like this for concurrent/asynchronous processing feels like a natural evolution after getting the generic slices package and, later, coroutines and iterators.