Browse Source

db: use `context` and go-mockgen for `LFSStore` (#7038)

Joe Chen 2 years ago
parent
commit
94059f291d

+ 19 - 18
internal/db/lfs.go

@@ -5,6 +5,7 @@
 package db
 
 import (
+	"context"
 	"fmt"
 	"time"
 
@@ -19,24 +20,24 @@ import (
 // NOTE: All methods are sorted in alphabetical order.
 type LFSStore interface {
 	// CreateObject creates a LFS object record in database.
-	CreateObject(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error
-	// GetObjectByOID returns the LFS object with given OID. It returns ErrLFSObjectNotExist
-	// when not found.
-	GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error)
-	// GetObjectsByOIDs returns LFS objects found within "oids". The returned list could have
-	// less elements if some oids were not found.
-	GetObjectsByOIDs(repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error)
+	CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error
+	// GetObjectByOID returns the LFS object with given OID. It returns
+	// ErrLFSObjectNotExist when not found.
+	GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error)
+	// GetObjectsByOIDs returns LFS objects found within "oids". The returned list
+	// could have less elements if some oids were not found.
+	GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error)
 }
 
 var LFS LFSStore
 
 // LFSObject is the relation between an LFS object and a repository.
 type LFSObject struct {
-	RepoID    int64           `gorm:"PRIMARY_KEY;AUTO_INCREMENT:false"`
-	OID       lfsutil.OID     `gorm:"PRIMARY_KEY;COLUMN:oid"`
-	Size      int64           `gorm:"NOT NULL"`
-	Storage   lfsutil.Storage `gorm:"NOT NULL"`
-	CreatedAt time.Time       `gorm:"NOT NULL"`
+	RepoID    int64           `gorm:"primary_key;auto_increment:false"`
+	OID       lfsutil.OID     `gorm:"primary_key;column:oid"`
+	Size      int64           `gorm:"not null"`
+	Storage   lfsutil.Storage `gorm:"not null"`
+	CreatedAt time.Time       `gorm:"not null"`
 }
 
 var _ LFSStore = (*lfs)(nil)
@@ -45,14 +46,14 @@ type lfs struct {
 	*gorm.DB
 }
 
-func (db *lfs) CreateObject(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
+func (db *lfs) CreateObject(ctx context.Context, repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
 	object := &LFSObject{
 		RepoID:  repoID,
 		OID:     oid,
 		Size:    size,
 		Storage: storage,
 	}
-	return db.DB.Create(object).Error
+	return db.WithContext(ctx).Create(object).Error
 }
 
 type ErrLFSObjectNotExist struct {
@@ -72,9 +73,9 @@ func (ErrLFSObjectNotExist) NotFound() bool {
 	return true
 }
 
-func (db *lfs) GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error) {
+func (db *lfs) GetObjectByOID(ctx context.Context, repoID int64, oid lfsutil.OID) (*LFSObject, error) {
 	object := new(LFSObject)
-	err := db.Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error
+	err := db.WithContext(ctx).Where("repo_id = ? AND oid = ?", repoID, oid).First(object).Error
 	if err != nil {
 		if err == gorm.ErrRecordNotFound {
 			return nil, ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": oid}}
@@ -84,13 +85,13 @@ func (db *lfs) GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error)
 	return object, err
 }
 
-func (db *lfs) GetObjectsByOIDs(repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) {
+func (db *lfs) GetObjectsByOIDs(ctx context.Context, repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) {
 	if len(oids) == 0 {
 		return []*LFSObject{}, nil
 	}
 
 	objects := make([]*LFSObject, 0, len(oids))
-	err := db.Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error
+	err := db.WithContext(ctx).Where("repo_id = ? AND oid IN (?)", repoID, oids).Find(&objects).Error
 	if err != nil && err != gorm.ErrRecordNotFound {
 		return nil, err
 	}

+ 32 - 40
internal/db/lfs_test.go

@@ -5,16 +5,18 @@
 package db
 
 import (
+	"context"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"gogs.io/gogs/internal/errutil"
 	"gogs.io/gogs/internal/lfsutil"
 )
 
-func Test_lfs(t *testing.T) {
+func TestLFS(t *testing.T) {
 	if testing.Short() {
 		t.Skip()
 	}
@@ -30,16 +32,14 @@ func Test_lfs(t *testing.T) {
 		name string
 		test func(*testing.T, *lfs)
 	}{
-		{"CreateObject", test_lfs_CreateObject},
-		{"GetObjectByOID", test_lfs_GetObjectByOID},
-		{"GetObjectsByOIDs", test_lfs_GetObjectsByOIDs},
+		{"CreateObject", lfsCreateObject},
+		{"GetObjectByOID", lfsGetObjectByOID},
+		{"GetObjectsByOIDs", lfsGetObjectsByOIDs},
 	} {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Cleanup(func() {
 				err := clearTables(t, db.DB, tables...)
-				if err != nil {
-					t.Fatal(err)
-				}
+				require.NoError(t, err)
 			})
 			tc.test(t, db)
 		})
@@ -49,67 +49,59 @@ func Test_lfs(t *testing.T) {
 	}
 }
 
-func test_lfs_CreateObject(t *testing.T, db *lfs) {
+func lfsCreateObject(t *testing.T, db *lfs) {
+	ctx := context.Background()
+
 	// Create first LFS object
 	repoID := int64(1)
 	oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
-	err := db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal)
-	if err != nil {
-		t.Fatal(err)
-	}
+	err := db.CreateObject(ctx, repoID, oid, 12, lfsutil.StorageLocal)
+	require.NoError(t, err)
 
 	// Get it back and check the CreatedAt field
-	object, err := db.GetObjectByOID(repoID, oid)
-	if err != nil {
-		t.Fatal(err)
-	}
+	object, err := db.GetObjectByOID(ctx, repoID, oid)
+	require.NoError(t, err)
 	assert.Equal(t, db.NowFunc().Format(time.RFC3339), object.CreatedAt.UTC().Format(time.RFC3339))
 
 	// Try create second LFS object with same oid should fail
-	err = db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal)
+	err = db.CreateObject(ctx, repoID, oid, 12, lfsutil.StorageLocal)
 	assert.Error(t, err)
 }
 
-func test_lfs_GetObjectByOID(t *testing.T, db *lfs) {
+func lfsGetObjectByOID(t *testing.T, db *lfs) {
+	ctx := context.Background()
+
 	// Create a LFS object
 	repoID := int64(1)
 	oid := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
-	err := db.CreateObject(repoID, oid, 12, lfsutil.StorageLocal)
-	if err != nil {
-		t.Fatal(err)
-	}
+	err := db.CreateObject(ctx, repoID, oid, 12, lfsutil.StorageLocal)
+	require.NoError(t, err)
 
 	// We should be able to get it back
-	_, err = db.GetObjectByOID(repoID, oid)
-	if err != nil {
-		t.Fatal(err)
-	}
+	_, err = db.GetObjectByOID(ctx, repoID, oid)
+	require.NoError(t, err)
 
 	// Try to get a non-existent object
-	_, err = db.GetObjectByOID(repoID, "bad_oid")
+	_, err = db.GetObjectByOID(ctx, repoID, "bad_oid")
 	expErr := ErrLFSObjectNotExist{args: errutil.Args{"repoID": repoID, "oid": lfsutil.OID("bad_oid")}}
 	assert.Equal(t, expErr, err)
 }
 
-func test_lfs_GetObjectsByOIDs(t *testing.T, db *lfs) {
+func lfsGetObjectsByOIDs(t *testing.T, db *lfs) {
+	ctx := context.Background()
+
 	// Create two LFS objects
 	repoID := int64(1)
 	oid1 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f")
 	oid2 := lfsutil.OID("ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64g")
-	err := db.CreateObject(repoID, oid1, 12, lfsutil.StorageLocal)
-	if err != nil {
-		t.Fatal(err)
-	}
-	err = db.CreateObject(repoID, oid2, 12, lfsutil.StorageLocal)
-	if err != nil {
-		t.Fatal(err)
-	}
+	err := db.CreateObject(ctx, repoID, oid1, 12, lfsutil.StorageLocal)
+	require.NoError(t, err)
+	err = db.CreateObject(ctx, repoID, oid2, 12, lfsutil.StorageLocal)
+	require.NoError(t, err)
 
 	// We should be able to get them back and ignore non-existent ones
-	objects, err := db.GetObjectsByOIDs(repoID, oid1, oid2, "bad_oid")
-	if err != nil {
-		t.Fatal(err)
-	}
+	objects, err := db.GetObjectsByOIDs(ctx, repoID, oid1, oid2, "bad_oid")
+	require.NoError(t, err)
 	assert.Equal(t, 2, len(objects), "number of objects")
 
 	assert.Equal(t, repoID, objects[0].RepoID)

+ 1 - 23
internal/db/mock_gen.go

@@ -6,11 +6,9 @@ package db
 
 import (
 	"testing"
-
-	"gogs.io/gogs/internal/lfsutil"
 )
 
-//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i PermsStore -o mocks.go
+//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i PermsStore -o mocks.go
 
 func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
 	before := AccessTokens
@@ -20,26 +18,6 @@ func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
 	})
 }
 
-var _ LFSStore = (*MockLFSStore)(nil)
-
-type MockLFSStore struct {
-	MockCreateObject     func(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error
-	MockGetObjectByOID   func(repoID int64, oid lfsutil.OID) (*LFSObject, error)
-	MockGetObjectsByOIDs func(repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error)
-}
-
-func (m *MockLFSStore) CreateObject(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
-	return m.MockCreateObject(repoID, oid, size, storage)
-}
-
-func (m *MockLFSStore) GetObjectByOID(repoID int64, oid lfsutil.OID) (*LFSObject, error) {
-	return m.MockGetObjectByOID(repoID, oid)
-}
-
-func (m *MockLFSStore) GetObjectsByOIDs(repoID int64, oids ...lfsutil.OID) ([]*LFSObject, error) {
-	return m.MockGetObjectsByOIDs(repoID, oids...)
-}
-
 func SetMockLFSStore(t *testing.T, mock LFSStore) {
 	before := LFS
 	LFS = mock

+ 419 - 0
internal/db/mocks.go

@@ -5,6 +5,8 @@ package db
 import (
 	"context"
 	"sync"
+
+	lfsutil "gogs.io/gogs/internal/lfsutil"
 )
 
 // MockAccessTokensStore is a mock implementation of the AccessTokensStore
@@ -658,6 +660,423 @@ func (c AccessTokensStoreTouchFuncCall) Results() []interface{} {
 	return []interface{}{c.Result0}
 }
 
+// MockLFSStore is a mock implementation of the LFSStore interface (from the
+// package gogs.io/gogs/internal/db) used for unit testing.
+type MockLFSStore struct {
+	// CreateObjectFunc is an instance of a mock function object controlling
+	// the behavior of the method CreateObject.
+	CreateObjectFunc *LFSStoreCreateObjectFunc
+	// GetObjectByOIDFunc is an instance of a mock function object
+	// controlling the behavior of the method GetObjectByOID.
+	GetObjectByOIDFunc *LFSStoreGetObjectByOIDFunc
+	// GetObjectsByOIDsFunc is an instance of a mock function object
+	// controlling the behavior of the method GetObjectsByOIDs.
+	GetObjectsByOIDsFunc *LFSStoreGetObjectsByOIDsFunc
+}
+
+// NewMockLFSStore creates a new mock of the LFSStore interface. All methods
+// return zero values for all results, unless overwritten.
+func NewMockLFSStore() *MockLFSStore {
+	return &MockLFSStore{
+		CreateObjectFunc: &LFSStoreCreateObjectFunc{
+			defaultHook: func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) (r0 error) {
+				return
+			},
+		},
+		GetObjectByOIDFunc: &LFSStoreGetObjectByOIDFunc{
+			defaultHook: func(context.Context, int64, lfsutil.OID) (r0 *LFSObject, r1 error) {
+				return
+			},
+		},
+		GetObjectsByOIDsFunc: &LFSStoreGetObjectsByOIDsFunc{
+			defaultHook: func(context.Context, int64, ...lfsutil.OID) (r0 []*LFSObject, r1 error) {
+				return
+			},
+		},
+	}
+}
+
+// NewStrictMockLFSStore creates a new mock of the LFSStore interface. All
+// methods panic on invocation, unless overwritten.
+func NewStrictMockLFSStore() *MockLFSStore {
+	return &MockLFSStore{
+		CreateObjectFunc: &LFSStoreCreateObjectFunc{
+			defaultHook: func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error {
+				panic("unexpected invocation of MockLFSStore.CreateObject")
+			},
+		},
+		GetObjectByOIDFunc: &LFSStoreGetObjectByOIDFunc{
+			defaultHook: func(context.Context, int64, lfsutil.OID) (*LFSObject, error) {
+				panic("unexpected invocation of MockLFSStore.GetObjectByOID")
+			},
+		},
+		GetObjectsByOIDsFunc: &LFSStoreGetObjectsByOIDsFunc{
+			defaultHook: func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) {
+				panic("unexpected invocation of MockLFSStore.GetObjectsByOIDs")
+			},
+		},
+	}
+}
+
+// NewMockLFSStoreFrom creates a new mock of the MockLFSStore interface. All
+// methods delegate to the given implementation, unless overwritten.
+func NewMockLFSStoreFrom(i LFSStore) *MockLFSStore {
+	return &MockLFSStore{
+		CreateObjectFunc: &LFSStoreCreateObjectFunc{
+			defaultHook: i.CreateObject,
+		},
+		GetObjectByOIDFunc: &LFSStoreGetObjectByOIDFunc{
+			defaultHook: i.GetObjectByOID,
+		},
+		GetObjectsByOIDsFunc: &LFSStoreGetObjectsByOIDsFunc{
+			defaultHook: i.GetObjectsByOIDs,
+		},
+	}
+}
+
+// LFSStoreCreateObjectFunc describes the behavior when the CreateObject
+// method of the parent MockLFSStore instance is invoked.
+type LFSStoreCreateObjectFunc struct {
+	defaultHook func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error
+	hooks       []func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error
+	history     []LFSStoreCreateObjectFuncCall
+	mutex       sync.Mutex
+}
+
+// CreateObject delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockLFSStore) CreateObject(v0 context.Context, v1 int64, v2 lfsutil.OID, v3 int64, v4 lfsutil.Storage) error {
+	r0 := m.CreateObjectFunc.nextHook()(v0, v1, v2, v3, v4)
+	m.CreateObjectFunc.appendCall(LFSStoreCreateObjectFuncCall{v0, v1, v2, v3, v4, r0})
+	return r0
+}
+
+// SetDefaultHook sets function that is called when the CreateObject method
+// of the parent MockLFSStore instance is invoked and the hook queue is
+// empty.
+func (f *LFSStoreCreateObjectFunc) SetDefaultHook(hook func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// CreateObject method of the parent MockLFSStore instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *LFSStoreCreateObjectFunc) PushHook(hook func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *LFSStoreCreateObjectFunc) SetDefaultReturn(r0 error) {
+	f.SetDefaultHook(func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error {
+		return r0
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *LFSStoreCreateObjectFunc) PushReturn(r0 error) {
+	f.PushHook(func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error {
+		return r0
+	})
+}
+
+func (f *LFSStoreCreateObjectFunc) nextHook() func(context.Context, int64, lfsutil.OID, int64, lfsutil.Storage) error {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *LFSStoreCreateObjectFunc) appendCall(r0 LFSStoreCreateObjectFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of LFSStoreCreateObjectFuncCall objects
+// describing the invocations of this function.
+func (f *LFSStoreCreateObjectFunc) History() []LFSStoreCreateObjectFuncCall {
+	f.mutex.Lock()
+	history := make([]LFSStoreCreateObjectFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// LFSStoreCreateObjectFuncCall is an object that describes an invocation of
+// method CreateObject on an instance of MockLFSStore.
+type LFSStoreCreateObjectFuncCall struct {
+	// Arg0 is the value of the 1st argument passed to this method
+	// invocation.
+	Arg0 context.Context
+	// Arg1 is the value of the 2nd argument passed to this method
+	// invocation.
+	Arg1 int64
+	// Arg2 is the value of the 3rd argument passed to this method
+	// invocation.
+	Arg2 lfsutil.OID
+	// Arg3 is the value of the 4th argument passed to this method
+	// invocation.
+	Arg3 int64
+	// Arg4 is the value of the 5th argument passed to this method
+	// invocation.
+	Arg4 lfsutil.Storage
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c LFSStoreCreateObjectFuncCall) Args() []interface{} {
+	return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3, c.Arg4}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c LFSStoreCreateObjectFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0}
+}
+
+// LFSStoreGetObjectByOIDFunc describes the behavior when the GetObjectByOID
+// method of the parent MockLFSStore instance is invoked.
+type LFSStoreGetObjectByOIDFunc struct {
+	defaultHook func(context.Context, int64, lfsutil.OID) (*LFSObject, error)
+	hooks       []func(context.Context, int64, lfsutil.OID) (*LFSObject, error)
+	history     []LFSStoreGetObjectByOIDFuncCall
+	mutex       sync.Mutex
+}
+
+// GetObjectByOID delegates to the next hook function in the queue and
+// stores the parameter and result values of this invocation.
+func (m *MockLFSStore) GetObjectByOID(v0 context.Context, v1 int64, v2 lfsutil.OID) (*LFSObject, error) {
+	r0, r1 := m.GetObjectByOIDFunc.nextHook()(v0, v1, v2)
+	m.GetObjectByOIDFunc.appendCall(LFSStoreGetObjectByOIDFuncCall{v0, v1, v2, r0, r1})
+	return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the GetObjectByOID
+// method of the parent MockLFSStore instance is invoked and the hook queue
+// is empty.
+func (f *LFSStoreGetObjectByOIDFunc) SetDefaultHook(hook func(context.Context, int64, lfsutil.OID) (*LFSObject, error)) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// GetObjectByOID method of the parent MockLFSStore instance invokes the
+// hook at the front of the queue and discards it. After the queue is empty,
+// the default hook function is invoked for any future action.
+func (f *LFSStoreGetObjectByOIDFunc) PushHook(hook func(context.Context, int64, lfsutil.OID) (*LFSObject, error)) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *LFSStoreGetObjectByOIDFunc) SetDefaultReturn(r0 *LFSObject, r1 error) {
+	f.SetDefaultHook(func(context.Context, int64, lfsutil.OID) (*LFSObject, error) {
+		return r0, r1
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *LFSStoreGetObjectByOIDFunc) PushReturn(r0 *LFSObject, r1 error) {
+	f.PushHook(func(context.Context, int64, lfsutil.OID) (*LFSObject, error) {
+		return r0, r1
+	})
+}
+
+func (f *LFSStoreGetObjectByOIDFunc) nextHook() func(context.Context, int64, lfsutil.OID) (*LFSObject, error) {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *LFSStoreGetObjectByOIDFunc) appendCall(r0 LFSStoreGetObjectByOIDFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of LFSStoreGetObjectByOIDFuncCall objects
+// describing the invocations of this function.
+func (f *LFSStoreGetObjectByOIDFunc) History() []LFSStoreGetObjectByOIDFuncCall {
+	f.mutex.Lock()
+	history := make([]LFSStoreGetObjectByOIDFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// LFSStoreGetObjectByOIDFuncCall is an object that describes an invocation
+// of method GetObjectByOID on an instance of MockLFSStore.
+type LFSStoreGetObjectByOIDFuncCall struct {
+	// Arg0 is the value of the 1st argument passed to this method
+	// invocation.
+	Arg0 context.Context
+	// Arg1 is the value of the 2nd argument passed to this method
+	// invocation.
+	Arg1 int64
+	// Arg2 is the value of the 3rd argument passed to this method
+	// invocation.
+	Arg2 lfsutil.OID
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 *LFSObject
+	// Result1 is the value of the 2nd result returned from this method
+	// invocation.
+	Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c LFSStoreGetObjectByOIDFuncCall) Args() []interface{} {
+	return []interface{}{c.Arg0, c.Arg1, c.Arg2}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c LFSStoreGetObjectByOIDFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0, c.Result1}
+}
+
+// LFSStoreGetObjectsByOIDsFunc describes the behavior when the
+// GetObjectsByOIDs method of the parent MockLFSStore instance is invoked.
+type LFSStoreGetObjectsByOIDsFunc struct {
+	defaultHook func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error)
+	hooks       []func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error)
+	history     []LFSStoreGetObjectsByOIDsFuncCall
+	mutex       sync.Mutex
+}
+
+// GetObjectsByOIDs delegates to the next hook function in the queue and
+// stores the parameter and result values of this invocation.
+func (m *MockLFSStore) GetObjectsByOIDs(v0 context.Context, v1 int64, v2 ...lfsutil.OID) ([]*LFSObject, error) {
+	r0, r1 := m.GetObjectsByOIDsFunc.nextHook()(v0, v1, v2...)
+	m.GetObjectsByOIDsFunc.appendCall(LFSStoreGetObjectsByOIDsFuncCall{v0, v1, v2, r0, r1})
+	return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the GetObjectsByOIDs
+// method of the parent MockLFSStore instance is invoked and the hook queue
+// is empty.
+func (f *LFSStoreGetObjectsByOIDsFunc) SetDefaultHook(hook func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error)) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// GetObjectsByOIDs method of the parent MockLFSStore instance invokes the
+// hook at the front of the queue and discards it. After the queue is empty,
+// the default hook function is invoked for any future action.
+func (f *LFSStoreGetObjectsByOIDsFunc) PushHook(hook func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error)) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *LFSStoreGetObjectsByOIDsFunc) SetDefaultReturn(r0 []*LFSObject, r1 error) {
+	f.SetDefaultHook(func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) {
+		return r0, r1
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *LFSStoreGetObjectsByOIDsFunc) PushReturn(r0 []*LFSObject, r1 error) {
+	f.PushHook(func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) {
+		return r0, r1
+	})
+}
+
+func (f *LFSStoreGetObjectsByOIDsFunc) nextHook() func(context.Context, int64, ...lfsutil.OID) ([]*LFSObject, error) {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *LFSStoreGetObjectsByOIDsFunc) appendCall(r0 LFSStoreGetObjectsByOIDsFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of LFSStoreGetObjectsByOIDsFuncCall objects
+// describing the invocations of this function.
+func (f *LFSStoreGetObjectsByOIDsFunc) History() []LFSStoreGetObjectsByOIDsFuncCall {
+	f.mutex.Lock()
+	history := make([]LFSStoreGetObjectsByOIDsFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// LFSStoreGetObjectsByOIDsFuncCall is an object that describes an
+// invocation of method GetObjectsByOIDs on an instance of MockLFSStore.
+type LFSStoreGetObjectsByOIDsFuncCall struct {
+	// Arg0 is the value of the 1st argument passed to this method
+	// invocation.
+	Arg0 context.Context
+	// Arg1 is the value of the 2nd argument passed to this method
+	// invocation.
+	Arg1 int64
+	// Arg2 is a slice containing the values of the variadic arguments
+	// passed to this method invocation.
+	Arg2 []lfsutil.OID
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 []*LFSObject
+	// Result1 is the value of the 2nd result returned from this method
+	// invocation.
+	Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation. The variadic slice argument is flattened in this array such
+// that one positional argument and three variadic arguments would result in
+// a slice of four, not two.
+func (c LFSStoreGetObjectsByOIDsFuncCall) Args() []interface{} {
+	trailing := []interface{}{}
+	for _, val := range c.Arg2 {
+		trailing = append(trailing, val)
+	}
+
+	return append([]interface{}{c.Arg0, c.Arg1}, trailing...)
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c LFSStoreGetObjectsByOIDsFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0, c.Result1}
+}
+
 // MockPermsStore is a mock implementation of the PermsStore interface (from
 // the package gogs.io/gogs/internal/db) used for unit testing.
 type MockPermsStore struct {

+ 4 - 4
internal/route/lfs/basic.go

@@ -44,7 +44,7 @@ func (h *basicHandler) Storager(storage lfsutil.Storage) lfsutil.Storager {
 
 // GET /{owner}/{repo}.git/info/lfs/object/basic/{oid}
 func (h *basicHandler) serveDownload(c *macaron.Context, repo *db.Repository, oid lfsutil.OID) {
-	object, err := db.LFS.GetObjectByOID(repo.ID, oid)
+	object, err := db.LFS.GetObjectByOID(c.Req.Context(), repo.ID, oid)
 	if err != nil {
 		if db.IsErrLFSObjectNotExist(err) {
 			responseJSON(c.Resp, http.StatusNotFound, responseError{
@@ -79,7 +79,7 @@ func (h *basicHandler) serveDownload(c *macaron.Context, repo *db.Repository, oi
 func (h *basicHandler) serveUpload(c *macaron.Context, repo *db.Repository, oid lfsutil.OID) {
 	// NOTE: LFS client will retry upload the same object if there was a partial failure,
 	// therefore we would like to skip ones that already exist.
-	_, err := db.LFS.GetObjectByOID(repo.ID, oid)
+	_, err := db.LFS.GetObjectByOID(c.Req.Context(), repo.ID, oid)
 	if err == nil {
 		// Object exists, drain the request body and we're good.
 		_, _ = io.Copy(ioutil.Discard, c.Req.Request.Body)
@@ -106,7 +106,7 @@ func (h *basicHandler) serveUpload(c *macaron.Context, repo *db.Repository, oid
 		return
 	}
 
-	err = db.LFS.CreateObject(repo.ID, oid, written, s.Storage())
+	err = db.LFS.CreateObject(c.Req.Context(), repo.ID, oid, written, s.Storage())
 	if err != nil {
 		// NOTE: It is OK to leave the file when the whole operation failed
 		// with a DB error, a retry on client side can safely overwrite the
@@ -139,7 +139,7 @@ func (*basicHandler) serveVerify(c *macaron.Context, repo *db.Repository) {
 		return
 	}
 
-	object, err := db.LFS.GetObjectByOID(repo.ID, request.Oid)
+	object, err := db.LFS.GetObjectByOID(c.Req.Context(), repo.ID, request.Oid)
 	if err != nil {
 		if db.IsErrLFSObjectNotExist(err) {
 			responseJSON(c.Resp, http.StatusNotFound, responseError{

+ 44 - 42
internal/route/lfs/basic_test.go

@@ -61,17 +61,17 @@ func Test_basicHandler_serveDownload(t *testing.T) {
 	tests := []struct {
 		name          string
 		content       string
-		mockLFSStore  *db.MockLFSStore
+		mockLFSStore  func() db.LFSStore
 		expStatusCode int
 		expHeader     http.Header
 		expBody       string
 	}{
 		{
 			name: "object does not exist",
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) {
-					return nil, db.ErrLFSObjectNotExist{}
-				},
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectByOIDFunc.SetDefaultReturn(nil, db.ErrLFSObjectNotExist{})
+				return mock
 			},
 			expStatusCode: http.StatusNotFound,
 			expHeader: http.Header{
@@ -81,10 +81,10 @@ func Test_basicHandler_serveDownload(t *testing.T) {
 		},
 		{
 			name: "storage not found",
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) {
-					return &db.LFSObject{Storage: "bad_storage"}, nil
-				},
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectByOIDFunc.SetDefaultReturn(&db.LFSObject{Storage: "bad_storage"}, nil)
+				return mock
 			},
 			expStatusCode: http.StatusInternalServerError,
 			expHeader: http.Header{
@@ -96,13 +96,16 @@ func Test_basicHandler_serveDownload(t *testing.T) {
 		{
 			name:    "object exists",
 			content: "Hello world!",
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) {
-					return &db.LFSObject{
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectByOIDFunc.SetDefaultReturn(
+					&db.LFSObject{
 						Size:    12,
 						Storage: s.Storage(),
-					}, nil
-				},
+					},
+					nil,
+				)
+				return mock
 			},
 			expStatusCode: http.StatusOK,
 			expHeader: http.Header{
@@ -114,7 +117,7 @@ func Test_basicHandler_serveDownload(t *testing.T) {
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			db.SetMockLFSStore(t, test.mockLFSStore)
+			db.SetMockLFSStore(t, test.mockLFSStore())
 
 			s.buf = bytes.NewBufferString(test.content)
 
@@ -158,35 +161,32 @@ func Test_basicHandler_serveUpload(t *testing.T) {
 
 	tests := []struct {
 		name          string
-		mockLFSStore  *db.MockLFSStore
+		mockLFSStore  func() db.LFSStore
 		expStatusCode int
 		expBody       string
 	}{
 		{
 			name: "object already exists",
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) {
-					return &db.LFSObject{}, nil
-				},
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectByOIDFunc.SetDefaultReturn(&db.LFSObject{}, nil)
+				return mock
 			},
 			expStatusCode: http.StatusOK,
 		},
 		{
 			name: "new object",
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) {
-					return nil, db.ErrLFSObjectNotExist{}
-				},
-				MockCreateObject: func(repoID int64, oid lfsutil.OID, size int64, storage lfsutil.Storage) error {
-					return nil
-				},
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectByOIDFunc.SetDefaultReturn(nil, db.ErrLFSObjectNotExist{})
+				return mock
 			},
 			expStatusCode: http.StatusOK,
 		},
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			db.SetMockLFSStore(t, test.mockLFSStore)
+			db.SetMockLFSStore(t, test.mockLFSStore())
 
 			r, err := http.NewRequest("PUT", "/", strings.NewReader("Hello world!"))
 			if err != nil {
@@ -219,7 +219,7 @@ func Test_basicHandler_serveVerify(t *testing.T) {
 	tests := []struct {
 		name          string
 		body          string
-		mockLFSStore  *db.MockLFSStore
+		mockLFSStore  func() db.LFSStore
 		expStatusCode int
 		expBody       string
 	}{
@@ -232,10 +232,10 @@ func Test_basicHandler_serveVerify(t *testing.T) {
 		{
 			name: "object does not exist",
 			body: `{"oid":"ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f"}`,
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) {
-					return nil, db.ErrLFSObjectNotExist{}
-				},
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectByOIDFunc.SetDefaultReturn(nil, db.ErrLFSObjectNotExist{})
+				return mock
 			},
 			expStatusCode: http.StatusNotFound,
 			expBody:       `{"message":"Object does not exist"}` + "\n",
@@ -243,10 +243,10 @@ func Test_basicHandler_serveVerify(t *testing.T) {
 		{
 			name: "object size mismatch",
 			body: `{"oid":"ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f"}`,
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) {
-					return &db.LFSObject{Size: 12}, nil
-				},
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectByOIDFunc.SetDefaultReturn(&db.LFSObject{Size: 12}, nil)
+				return mock
 			},
 			expStatusCode: http.StatusBadRequest,
 			expBody:       `{"message":"Object size mismatch"}` + "\n",
@@ -255,17 +255,19 @@ func Test_basicHandler_serveVerify(t *testing.T) {
 		{
 			name: "object exists",
 			body: `{"oid":"ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f", "size":12}`,
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectByOID: func(repoID int64, oid lfsutil.OID) (*db.LFSObject, error) {
-					return &db.LFSObject{Size: 12}, nil
-				},
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectByOIDFunc.SetDefaultReturn(&db.LFSObject{Size: 12}, nil)
+				return mock
 			},
 			expStatusCode: http.StatusOK,
 		},
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			db.SetMockLFSStore(t, test.mockLFSStore)
+			if test.mockLFSStore != nil {
+				db.SetMockLFSStore(t, test.mockLFSStore())
+			}
 
 			r, err := http.NewRequest("POST", "/", strings.NewReader(test.body))
 			if err != nil {

+ 1 - 1
internal/route/lfs/batch.go

@@ -75,7 +75,7 @@ func serveBatch(c *macaron.Context, owner *db.User, repo *db.Repository) {
 		for _, obj := range request.Objects {
 			oids = append(oids, obj.Oid)
 		}
-		stored, err := db.LFS.GetObjectsByOIDs(repo.ID, oids...)
+		stored, err := db.LFS.GetObjectsByOIDs(c.Req.Context(), repo.ID, oids...)
 		if err != nil {
 			internalServerError(c.Resp)
 			log.Error("Failed to get objects [repo_id: %d, oids: %v]: %v", repo.ID, oids, err)

+ 12 - 8
internal/route/lfs/batch_test.go

@@ -17,7 +17,6 @@ import (
 
 	"gogs.io/gogs/internal/conf"
 	"gogs.io/gogs/internal/db"
-	"gogs.io/gogs/internal/lfsutil"
 )
 
 func Test_serveBatch(t *testing.T) {
@@ -35,7 +34,7 @@ func Test_serveBatch(t *testing.T) {
 	tests := []struct {
 		name          string
 		body          string
-		mockLFSStore  *db.MockLFSStore
+		mockLFSStore  func() db.LFSStore
 		expStatusCode int
 		expBody       string
 	}{
@@ -83,9 +82,10 @@ func Test_serveBatch(t *testing.T) {
 	{"oid": "ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f", "size": 123},
 	{"oid": "5cac0a318669fadfee734fb340a5f5b70b428ac57a9f4b109cb6e150b2ba7e57", "size": 456}
 ]}`,
-			mockLFSStore: &db.MockLFSStore{
-				MockGetObjectsByOIDs: func(repoID int64, oids ...lfsutil.OID) ([]*db.LFSObject, error) {
-					return []*db.LFSObject{
+			mockLFSStore: func() db.LFSStore {
+				mock := db.NewMockLFSStore()
+				mock.GetObjectsByOIDsFunc.SetDefaultReturn(
+					[]*db.LFSObject{
 						{
 							OID:  "ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f",
 							Size: 1234,
@@ -93,8 +93,10 @@ func Test_serveBatch(t *testing.T) {
 							OID:  "5cac0a318669fadfee734fb340a5f5b70b428ac57a9f4b109cb6e150b2ba7e57",
 							Size: 456,
 						},
-					}, nil
-				},
+					},
+					nil,
+				)
+				return mock
 			},
 			expStatusCode: http.StatusOK,
 			expBody: `{
@@ -121,7 +123,9 @@ func Test_serveBatch(t *testing.T) {
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			db.SetMockLFSStore(t, test.mockLFSStore)
+			if test.mockLFSStore != nil {
+				db.SetMockLFSStore(t, test.mockLFSStore())
+			}
 
 			r, err := http.NewRequest("POST", "/", bytes.NewBufferString(test.body))
 			if err != nil {