Skip to content
45 changes: 37 additions & 8 deletions iter/iter.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package iter

import (
"context"
"runtime"
"sync/atomic"

"github.com/sourcegraph/conc"
"github.com/sourcegraph/conc/pool"
)

// defaultMaxGoroutines returns the default maximum number of
Expand Down Expand Up @@ -57,29 +58,57 @@ func ForEachIdx[T any](input []T, f func(int, *T)) { Iterator[T]{}.ForEachIdx(in
// ForEachIdx is the same as ForEach except it also provides the
// index of the element to the callback.
func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) {
_ = iter.ForEachIdxCtx(context.Background(), input, func(_ context.Context, idx int, input *T) error {
f(idx, input)
return nil
})
}

func ForEachIdxCtx[T any](octx context.Context, input []T, f func(context.Context, int, *T) error) error {
return Iterator[T]{}.ForEachIdxCtx(octx, input, f)
}

func ForEachCtx[T any](octx context.Context, input []T, f func(context.Context, *T) error) error {
return Iterator[T]{}.ForEachCtx(octx, input, f)
}

func (iter Iterator[T]) ForEachCtx(octx context.Context, input []T, f func(context.Context, *T) error) error {
return iter.ForEachIdxCtx(octx, input, func(ctx context.Context, _ int, input *T) error {
return f(ctx, input)
})
}

func (iter Iterator[T]) ForEachIdxCtx(octx context.Context, input []T, f func(context.Context, int, *T) error) error {
if iter.MaxGoroutines == 0 {
// iter is a value receiver and is hence safe to mutate
iter.MaxGoroutines = defaultMaxGoroutines()
}

numInput := len(input)
if iter.MaxGoroutines > numInput {
if iter.MaxGoroutines > numInput && numInput > 0 {
// No more concurrent tasks than the number of input items.
iter.MaxGoroutines = numInput
}

var idx atomic.Int64
// Create the task outside the loop to avoid extra closure allocations.
task := func() {
task := func(ctx context.Context) error {
i := int(idx.Add(1) - 1)
for ; i < numInput; i = int(idx.Add(1) - 1) {
f(i, &input[i])
for ; i < numInput && ctx.Err() == nil; i = int(idx.Add(1) - 1) {
if err := f(ctx, i, &input[i]); err != nil {
return err
}
}
return nil
}

var wg conc.WaitGroup
runner := pool.New().
WithContext(octx).
WithCancelOnError().
WithFirstError().
WithMaxGoroutines(iter.MaxGoroutines)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to not add a dependency on pool in the iter package. Because the iter package knows the number of inputs and the number of outputs in advance, it can be considerably more efficient than the pool package, which must work for an unbounded number of iterations.

I think it would be good to reconcile this PR with the patterns in #104. In particular, the FailFast flag should mean similar things iterating with ForEachErr and ForEachCtx.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The abstractions built into ContextPool offered a simple way to achieve the functionality I was looking for with regard to context cancellation without having to duplicate code. Both iter and pool used waitgroup underneath. I'm not too sure what efficiency gains are made by having two separate implementations but I'd be happy to switch it back to original setup.

Additionally, as a further argument for using ContextPool, the FailFast behavior seems like it could be easily implemented by using the new bool to optionally call WithFirstError() on the underlying pool. Right now I call that by default assuming the caller is only interested in the first error when using a context.

I could add the *Err variants as extensions to what I have here that just return an error but don't require a context.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the FailFast behavior seems like it could be easily implemented by using the new bool to optionally call WithFirstError() on the underlying pool

Related: #118

The abstractions built into ContextPool offered a simple way to achieve the functionality I was looking for with regard to context cancellation without having to duplicate code

You might be right. The simplicity of using a Pool to back the iter package is likely more valuable than the minor efficiency gains we get from knowing the size of the set in advance (which basically boil down to allocating in advance and not needing a mutex).

Let me noodle on the design a bit and get back to ya. I'll probably open a draft that unifies this with #104 and the pool package in general, maybe just replacing the Iterator structs with configured pools.

for i := 0; i < iter.MaxGoroutines; i++ {
wg.Go(task)
runner.Go(task)
}
wg.Wait()
return runner.Wait()
}
29 changes: 22 additions & 7 deletions iter/map.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package iter

import (
"context"
"sync"

"github.com/sourcegraph/conc/internal/multierror"
Expand All @@ -25,9 +26,8 @@ func Map[T, R any](input []T, f func(*T) R) []R {
//
// Map uses up to the configured Mapper's maximum number of goroutines.
func (m Mapper[T, R]) Map(input []T, f func(*T) R) []R {
res := make([]R, len(input))
Iterator[T](m).ForEachIdx(input, func(i int, t *T) {
res[i] = f(t)
res, _ := m.MapErr(input, func(t *T) (R, error) {
return f(t), nil
})
return res
}
Expand All @@ -47,19 +47,34 @@ func MapErr[T, R any](input []T, f func(*T) (R, error)) ([]R, error) {
// Map uses up to the configured Mapper's maximum number of goroutines.
func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) {
var (
res = make([]R, len(input))
errMux sync.Mutex
errs error
)
Iterator[T](m).ForEachIdx(input, func(i int, t *T) {
var err error
res[i], err = f(t)
// MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapErrCtx
res, _ := m.MapErrCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) {
ires, err := f(t)
if err != nil {
errMux.Lock()
// TODO: use stdlib errors once multierrors land in go 1.20
errs = multierror.Join(errs, err)
errMux.Unlock()
}
return ires, nil
})
return res, errs
}

func MapErrCtx[T, R any](octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) {
return Mapper[T, R]{}.MapErrCtx(octx, input, f)
}

func (m Mapper[T, R]) MapErrCtx(octx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) {
var (
res = make([]R, len(input))
)
return res, Iterator[T](m).ForEachIdxCtx(octx, input, func(ctx context.Context, i int, t *T) error {
var err error
res[i], err = f(ctx, t)
return err
})
}