Add v6 API to fix page pool leaks on context cancel

* Extend the `Pager` interface with an `Abort()` method.
* Add a `cancel` function and `done` channel to the `pager` struct to support `Abort`.
* Track the page pool size explicitly in the `pager` struct.
* Fix the page pool leak on context cancellation (see the sketch below).
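
A sketch of the intended shutdown flow (`shouldStop` and `pager` are
illustrative names; `Abort` is the new method):

    for item := range pager.Iter() {
        if shouldStop(item) {
            break // stop consuming early
        }
    }
    _ = pager.Abort() // cancels the context, waits for the paging
                      // goroutine, and reports context.Canceled
    // every page borrowed by Iter is now back in the pool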
Jonathan Storm, 3 months ago
Parent commit: d5aff46571
3 changed files with 445 additions and 0 deletions
  1. v6/depager.go (+184, -0)
  2. v6/depager_test.go (+258, -0)
  3. v6/go.mod (+3, -0)

v6/depager.go (+184, -0)

@@ -0,0 +1,184 @@
+/*
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/.
+ */
+
+package depager
+
+import (
+	"context"
+	"fmt"
+)
+
+/*
+The `Page` interface must wrap server responses. This
+allows pagers to calculate page sizes and iterate over
+response aggregates.
+
+If the underlying value of this interface is `nil` (e.g. a
+nil pointer to a struct or a nil slice), `Elems()` will
+panic.
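+
+A minimal implementation (a sketch; the type and field names here
+are illustrative, not part of this package) might wrap a decoded
+response body:
+
+	type userPage struct {
+		Items []User `json:"items"`
+		Total uint64 `json:"total"`
+		Self  string `json:"self"`
+	}
+
+	func (p *userPage) Elems() []User { return p.Items }
+	func (p *userPage) URI() string   { return p.Self }
+	func (p *userPage) Count() uint64 { return p.Total }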
+*/
+type Page[T any] interface {
+	// Elems must return the items from this page
+	Elems() []T
+
+	// URI must return the URI associated with this page
+	URI() string
+
+	// Count must return the total number of items being paged
+	Count() uint64
+}
+
+// Client exposes the part of the API client that depager understands.
+type Client[T any] interface {
+	// NextPage returns the next page or it returns an error
+	NextPage(
+		page Page[T],
+		offset uint64, // item offset at which to start page
+	) (err error)
+}
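+
+/*
+A NextPage implementation is expected to fill the supplied page in
+place, starting at the given item offset. A sketch, assuming a
+hypothetical HTTP helper getJSON on the caller's client type:
+
+	func (c *myClient) NextPage(page depager.Page[User], offset uint64) error {
+		uri := fmt.Sprintf("%s/users?offset=%d", c.base, offset)
+		return c.getJSON(uri, page)
+	}
+*/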
+
+// Pager iterates over the items, or whole pages, of a paged response.
+type Pager[T any] interface {
+	// Iter is intended to be used in a for-range loop
+	Iter() <-chan T
+
+	// IterPages iterates over whole pages rather than items
+	IterPages() <-chan Page[T]
+
+	// LastErr must return the first error encountered, if any
+	LastErr() error
+
+	// Abort causes the pager to relinquish all pages back to
+	// the page pool and stop all running goroutines. Call it
+	// only after Iter or IterPages has started; otherwise it
+	// blocks forever waiting for the paging goroutine.
+}
+
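+/*
+NewPager builds a Pager that borrows pages from pagePool, fills them
+through c, and lends them to the consumer. The pool must be non-empty
+when NewPager is called; the page size is inferred from the capacity
+of the first page's element slice.
+
+Typical use (a sketch; newUserPage, pageSize, and process are
+illustrative):
+
+	pool := make(chan depager.Page[User], 4)
+	for i := 0; i < cap(pool); i++ {
+		pool <- newUserPage(pageSize)
+	}
+	pager := depager.NewPager(ctx, client, pool)
+	for u := range pager.Iter() {
+		process(u)
+	}
+	if err := pager.LastErr(); err != nil {
+		// handle the error
+	}
+*/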
+func NewPager[T any](
+	ctx context.Context,
+	c Client[T],
+	pagePool chan Page[T],
+) Pager[T] {
+	if len(pagePool) == 0 {
+		panic("new pager: provided page pool is empty")
+	}
+	pg := <-pagePool
+	pageSize := uint64(cap(pg.Elems()))
+	pagePool <- pg
+
+	ctx2, cancel := context.WithCancel(ctx)
+	done := make(chan struct{})
+
+	return &pager[T]{
+		ctx:      ctx2,
+		cancel:   cancel,
+		done:     done,
+		client:   c,
+		n:        pageSize,
+		pagePool: pagePool,
+		poolSize: len(pagePool),
+	}
+}
+
+/*
+pager retrieves the n items in the range [m*n, m*n + n - 1],
+inclusive, on each iteration; with n = 10, iteration m = 2 covers
+items 20 through 29. Up to len(pagePool) pages stay buffered.
+*/
+type pager[T any] struct {
+	ctx      context.Context
+	cancel   context.CancelFunc
+	done     chan struct{} // closed to tell Abort that iteration finished
+	client   Client[T]
+	m        uint64
+	n        uint64
+	err      error
+	pagePool chan Page[T]
+	poolSize int
+	cnt      uint64
+}
+
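+// iteratePages borrows a page from the pool, fills it via
+// client.NextPage at item offset m*n, and forwards it on the
+// returned channel, repeating until the reported count is
+// exhausted, an error occurs, or the context is canceled. On
+// error the borrowed page goes straight back to the pool.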
+func (p *pager[T]) iteratePages() <-chan Page[T] {
+	ch := make(chan Page[T], len(p.pagePool))
+	go func() {
+		defer close(ch)
+		var page Page[T]
+		for {
+			if p.ctx.Err() != nil {
+				break
+			}
+			page = <-p.pagePool
+			err := p.client.NextPage(page, p.m*p.n)
+			if err != nil {
+				p.err = err
+				p.pagePool <- page
+				return
+			}
+			if p.cnt == 0 {
+				p.cnt = page.Count()
+			}
+			ch <- page
+
+			if (p.m*p.n + p.n) >= p.cnt {
+				return
+			}
+			p.m++
+		}
+	}()
+	return ch
+}
+
+func (p *pager[T]) IterPages() <-chan Page[T] {
+	ch := make(chan Page[T], p.n)
+	go func() {
+		defer close(p.done)
+		defer close(ch)
+		for page := range p.iteratePages() {
+			if p.ctx.Err() != nil {
+				p.pagePool <- page
+				break
+			}
+			if p.err != nil {
+				p.err = fmt.Errorf("pager: iterate pages: %w", p.err)
+				p.pagePool <- page
+				return
+			}
+			ch <- page
+		}
+	}()
+	return ch
+}
+
+func (p *pager[T]) Iter() <-chan T {
+	ch := make(chan T, p.n)
+	go func() {
+		defer close(p.done)
+		defer close(ch)
+		for page := range p.iteratePages() {
+			if p.ctx.Err() != nil {
+				p.pagePool <- page
+				break
+			}
+			for _, i := range page.Elems() {
+				ch <- i
+			}
+			p.pagePool <- page
+			if p.err != nil {
+				p.err = fmt.Errorf("pager: iterate items: %w", p.err)
+				return
+			}
+		}
+	}()
+	return ch
+}
+
+func (p *pager[T]) LastErr() error {
+	return p.err
+}
+
+func (p *pager[T]) Abort() error {
+	p.cancel()
+	<-p.done
+	return p.ctx.Err()
+}

v6/depager_test.go (+258, -0)

@@ -0,0 +1,258 @@
+package depager
+
+import (
+	"context"
+	"fmt"
+	"testing"
+)
+
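+// NoopClient serves pre-baked pages from a slice in order, standing
+// in for a real API client; its err field, if set, is surfaced
+// through NextPage, and cnt is published as the total item count.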
+type NoopClient[T any] struct {
+	err   error
+	pages []*Aggr[T]
+	m     int
+	cnt   uint64
+}
+
+func (c *NoopClient[T]) NextPage(
+	page Page[T],
+	_offset uint64,
+) (err error) {
+	if len(c.pages) == 0 {
+		return
+	}
+	if c.m >= len(c.pages) {
+		err = fmt.Errorf("client: next page: exceeded max pages")
+		return
+	}
+	src := *c.pages[c.m]
+	dst := *page.(*Aggr[T])
+	dst = dst[:min(cap(dst), len(src))]
+	copy(dst, src)         // update values
+	*page.(*Aggr[T]) = dst // update slice
+
+	AggrCount = c.cnt
+	err = c.err
+	c.m++
+	return
+}
+
+func NewNoopClient[T any](
+	cnt int,
+	err error,
+	pages []*Aggr[T],
+) Client[T] {
+	return &NoopClient[T]{
+		cnt:   uint64(cnt),
+		err:   err,
+		pages: pages,
+	}
+}
+
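+// AggrCount backs Aggr.Count for every test page; NoopClient.NextPage
+// sets it to the configured total on each call.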
+var AggrCount uint64
+
+type Aggr[T any] []T
+
+func (a *Aggr[T]) Elems() []T {
+	return []T(*a)
+}
+
+func (a *Aggr[T]) URI() string {
+	return ""
+}
+
+func (a *Aggr[T]) Count() uint64 {
+	return AggrCount
+}
+
+func TestUsingNoopClient(t *testing.T) {
+	client := NewNoopClient[any](1, nil, []*Aggr[any]{{}})
+	pagePool := make(chan Page[any], 1)
+	for i := 0; i < cap(pagePool); i++ {
+		tmp := Aggr[any](make([]any, 0, 1))
+		pagePool <- &tmp
+	}
+	pager := NewPager(context.Background(), client, pagePool)
+	for range pager.Iter() {
+	}
+	if err := pager.LastErr(); err != nil {
+		t.Errorf("unexpected error in pager with noop client: %v", err)
+	}
+}
+
+func TestNoopClientReturnsError(t *testing.T) {
+	client := NewNoopClient[any](0, fmt.Errorf("whomp"),
+		[]*Aggr[any]{{}},
+	)
+	pagePool := make(chan Page[any], 1)
+	for i := 0; i < cap(pagePool); i++ {
+		tmp := Aggr[any](make([]any, 0))
+		pagePool <- &tmp
+	}
+	pager := NewPager(context.Background(), client, pagePool)
+	for range pager.Iter() {
+	}
+	if err := pager.LastErr(); err == nil {
+		t.Error("expected an error from the pager, got nil")
+	}
+}
+
+func TestClientReturnsNonemptyPage(t *testing.T) {
+	itemCount := 3
+	client := NewNoopClient[any](itemCount, nil,
+		[]*Aggr[any]{{1, 2}, {3}},
+	)
+	pagePool := make(chan Page[any], 1)
+	for i := 0; i < cap(pagePool); i++ {
+		tmp := Aggr[any](make([]any, 0, 2))
+		pagePool <- &tmp
+	}
+	pager := NewPager(context.Background(), client, pagePool)
+	var elem int
+	for e := range pager.Iter() {
+		elem = e.(int)
+	}
+	if err := pager.LastErr(); err != nil {
+		t.Errorf("unexpected error in pager: %v", err)
+	}
+	if elem != 3 {
+		t.Errorf("unexpected value: '%v'", elem)
+	}
+}
+
+func TestClientReturnsNonemptyPage2(t *testing.T) {
+	itemCount := 3
+	client := NewNoopClient[any](itemCount, nil,
+		[]*Aggr[any]{{1, 2}, {3}},
+	)
+	pagePool := make(chan Page[any], 1)
+	for i := 0; i < cap(pagePool); i++ {
+		tmp := Aggr[any](make([]any, 0, 2))
+		pagePool <- &tmp
+	}
+	pager := NewPager(context.Background(), client, pagePool)
+	var elem int
+	for p := range pager.IterPages() {
+		elem = p.Elems()[0].(int)
+		pagePool <- p
+	}
+	if err := pager.LastErr(); err != nil {
+		t.Errorf("unexpected error in pager: %v", err)
+	}
+	if elem != 3 {
+		t.Errorf("unexpected value: '%v'", elem)
+	}
+}
+
+func TestClientReturnsFewerPagesThanExpected(t *testing.T) {
+	pageSize := 1
+	itemCount := pageSize + 1
+	client := NewNoopClient[any](itemCount, nil,
+		[]*Aggr[any]{{0}},
+	)
+	pagePool := make(chan Page[any], 1)
+	for i := 0; i < cap(pagePool); i++ {
+		tmp := Aggr[any](make([]any, 0, 1))
+		pagePool <- &tmp
+	}
+	pager := NewPager(context.Background(), client, pagePool)
+	for range pager.Iter() {
+	}
+	if err := pager.LastErr(); err == nil {
+		t.Error("expected an error when pages run out, got nil")
+	}
+}
+
+func TestClientAbortsPagingItems(t *testing.T) {
+	pageSize := 2
+	itemCount := pageSize + 1
+	client := NewNoopClient[any](itemCount, nil,
+		[]*Aggr[any]{{0, 1}, {2}},
+	)
+	pagePool := make(chan Page[any], 2)
+	pg1 := Aggr[any](make([]any, 0, 2))
+	pagePool <- &pg1
+	pg2 := Aggr[any](make([]any, 0, 2))
+	pagePool <- &pg2
+
+	pager := NewPager(context.Background(), client, pagePool)
+	for range pager.Iter() {
+		break
+	}
+	if err := pager.Abort(); err == nil {
+		t.Error("expected Abort to report context cancellation, got nil")
+	}
+	if err := pager.LastErr(); err != nil {
+		t.Errorf("unexpected error in pager: %v", err)
+	}
+	if ps := len(pagePool); ps != 2 {
+		t.Errorf("unexpected number of pages in page pool: %d", ps)
+	}
+}
+
+func TestClientAbortsPaging(t *testing.T) {
+	cases := []struct{ poolLen, poolCap int }{
+		{1, 2},
+		{2, 2},
+	}
+	for _, c := range cases {
+		performAbortTest(t, c.poolLen, c.poolCap)
+	}
+}
+
+func performAbortTest(t *testing.T, poolLen, poolCap int) {
+	// Setup
+	pageSize := 2
+	itemCount := pageSize + 1
+	client := NewNoopClient[any](itemCount, nil,
+		[]*Aggr[any]{{0, 1}, {2}},
+	)
+	pagePool := make(chan Page[any], poolCap)
+	var pg Aggr[any]
+
+	for i := 0; i < poolLen; i++ {
+		pg = Aggr[any](make([]any, 0, 2))
+		pagePool <- &pg
+	}
+
+	// Abort paging prematurely
+	pager := NewPager(context.Background(), client, pagePool)
+	ch := pager.IterPages()
+	for p := range ch {
+		pagePool <- p
+		break
+	}
+	if err := pager.Abort(); err == nil {
+		t.Error("expected Abort to report context cancellation, got nil")
+	}
+
+	// Return any remaining buffered pages to the pool. The channel is
+	// closed by the time Abort returns, so ranging over it drains it.
+	for p := range ch {
+		pagePool <- p
+	}
+
+	// Test our assumptions
+	if err := pager.LastErr(); err != nil {
+		t.Errorf("unexpected error in pager: %v", err)
+	}
+	if len(pagePool) != poolLen {
+		t.Errorf("unexpected number of pages in page pool: %d", len(pagePool))
+	}
+	tmp := make(chan Page[any], poolLen)
+	pages := make(map[Page[any]]struct{})
+	for len(pagePool) > 0 {
+		p := <-pagePool
+		pages[p] = struct{}{}
+		tmp <- p
+	}
+	if len(pages) != len(tmp) {
+		t.Errorf("recovered pages are not unique: recovered %d pages but only %d are distinct", len(tmp), len(pages))
+	}
+}

v6/go.mod (+3, -0)

@@ -0,0 +1,3 @@
+module idio.link/go/depager/v6
+
+go 1.21