Browse Source

db: use `context` and go-mockgen for `TwoFactorsStore` (#7045)

Joe Chen 2 years ago
parent
commit
5e32058c13

+ 1 - 21
internal/db/mock_gen.go

@@ -8,7 +8,7 @@ import (
 	"testing"
 )
 
-//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i UsersStore -o mocks.go
+//go:generate go-mockgen -f gogs.io/gogs/internal/db -i AccessTokensStore -i LFSStore -i LoginSourcesStore -i LoginSourceFilesStore -i loginSourceFileStore -i PermsStore -i TwoFactorsStore -i UsersStore -o mocks.go
 
 func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
 	before := AccessTokens
@@ -60,26 +60,6 @@ func SetMockReposStore(t *testing.T, mock ReposStore) {
 	})
 }
 
-var _ TwoFactorsStore = (*MockTwoFactorsStore)(nil)
-
-type MockTwoFactorsStore struct {
-	MockCreate        func(userID int64, key, secret string) error
-	MockGetByUserID   func(userID int64) (*TwoFactor, error)
-	MockIsUserEnabled func(userID int64) bool
-}
-
-func (m *MockTwoFactorsStore) Create(userID int64, key, secret string) error {
-	return m.MockCreate(userID, key, secret)
-}
-
-func (m *MockTwoFactorsStore) GetByUserID(userID int64) (*TwoFactor, error) {
-	return m.MockGetByUserID(userID)
-}
-
-func (m *MockTwoFactorsStore) IsUserEnabled(userID int64) bool {
-	return m.MockIsUserEnabled(userID)
-}
-
 func SetMockTwoFactorsStore(t *testing.T, mock TwoFactorsStore) {
 	before := TwoFactors
 	TwoFactors = mock

+ 402 - 0
internal/db/mocks.go

@@ -2371,6 +2371,408 @@ func (c PermsStoreSetRepoPermsFuncCall) Results() []interface{} {
 	return []interface{}{c.Result0}
 }
 
+// MockTwoFactorsStore is a mock implementation of the TwoFactorsStore
+// interface (from the package gogs.io/gogs/internal/db) used for unit
+// testing.
+type MockTwoFactorsStore struct {
+	// CreateFunc is an instance of a mock function object controlling the
+	// behavior of the method Create.
+	CreateFunc *TwoFactorsStoreCreateFunc
+	// GetByUserIDFunc is an instance of a mock function object controlling
+	// the behavior of the method GetByUserID.
+	GetByUserIDFunc *TwoFactorsStoreGetByUserIDFunc
+	// IsUserEnabledFunc is an instance of a mock function object
+	// controlling the behavior of the method IsUserEnabled.
+	IsUserEnabledFunc *TwoFactorsStoreIsUserEnabledFunc
+}
+
+// NewMockTwoFactorsStore creates a new mock of the TwoFactorsStore
+// interface. All methods return zero values for all results, unless
+// overwritten.
+func NewMockTwoFactorsStore() *MockTwoFactorsStore {
+	return &MockTwoFactorsStore{
+		CreateFunc: &TwoFactorsStoreCreateFunc{
+			defaultHook: func(context.Context, int64, string, string) (r0 error) {
+				return
+			},
+		},
+		GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{
+			defaultHook: func(context.Context, int64) (r0 *TwoFactor, r1 error) {
+				return
+			},
+		},
+		IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{
+			defaultHook: func(context.Context, int64) (r0 bool) {
+				return
+			},
+		},
+	}
+}
+
+// NewStrictMockTwoFactorsStore creates a new mock of the TwoFactorsStore
+// interface. All methods panic on invocation, unless overwritten.
+func NewStrictMockTwoFactorsStore() *MockTwoFactorsStore {
+	return &MockTwoFactorsStore{
+		CreateFunc: &TwoFactorsStoreCreateFunc{
+			defaultHook: func(context.Context, int64, string, string) error {
+				panic("unexpected invocation of MockTwoFactorsStore.Create")
+			},
+		},
+		GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{
+			defaultHook: func(context.Context, int64) (*TwoFactor, error) {
+				panic("unexpected invocation of MockTwoFactorsStore.GetByUserID")
+			},
+		},
+		IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{
+			defaultHook: func(context.Context, int64) bool {
+				panic("unexpected invocation of MockTwoFactorsStore.IsUserEnabled")
+			},
+		},
+	}
+}
+
+// NewMockTwoFactorsStoreFrom creates a new mock of the MockTwoFactorsStore
+// interface. All methods delegate to the given implementation, unless
+// overwritten.
+func NewMockTwoFactorsStoreFrom(i TwoFactorsStore) *MockTwoFactorsStore {
+	return &MockTwoFactorsStore{
+		CreateFunc: &TwoFactorsStoreCreateFunc{
+			defaultHook: i.Create,
+		},
+		GetByUserIDFunc: &TwoFactorsStoreGetByUserIDFunc{
+			defaultHook: i.GetByUserID,
+		},
+		IsUserEnabledFunc: &TwoFactorsStoreIsUserEnabledFunc{
+			defaultHook: i.IsUserEnabled,
+		},
+	}
+}
+
+// TwoFactorsStoreCreateFunc describes the behavior when the Create method
+// of the parent MockTwoFactorsStore instance is invoked.
+type TwoFactorsStoreCreateFunc struct {
+	defaultHook func(context.Context, int64, string, string) error
+	hooks       []func(context.Context, int64, string, string) error
+	history     []TwoFactorsStoreCreateFuncCall
+	mutex       sync.Mutex
+}
+
+// Create delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockTwoFactorsStore) Create(v0 context.Context, v1 int64, v2 string, v3 string) error {
+	r0 := m.CreateFunc.nextHook()(v0, v1, v2, v3)
+	m.CreateFunc.appendCall(TwoFactorsStoreCreateFuncCall{v0, v1, v2, v3, r0})
+	return r0
+}
+
+// SetDefaultHook sets function that is called when the Create method of the
+// parent MockTwoFactorsStore instance is invoked and the hook queue is
+// empty.
+func (f *TwoFactorsStoreCreateFunc) SetDefaultHook(hook func(context.Context, int64, string, string) error) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// Create method of the parent MockTwoFactorsStore instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *TwoFactorsStoreCreateFunc) PushHook(hook func(context.Context, int64, string, string) error) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *TwoFactorsStoreCreateFunc) SetDefaultReturn(r0 error) {
+	f.SetDefaultHook(func(context.Context, int64, string, string) error {
+		return r0
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *TwoFactorsStoreCreateFunc) PushReturn(r0 error) {
+	f.PushHook(func(context.Context, int64, string, string) error {
+		return r0
+	})
+}
+
+func (f *TwoFactorsStoreCreateFunc) nextHook() func(context.Context, int64, string, string) error {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *TwoFactorsStoreCreateFunc) appendCall(r0 TwoFactorsStoreCreateFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of TwoFactorsStoreCreateFuncCall objects
+// describing the invocations of this function.
+func (f *TwoFactorsStoreCreateFunc) History() []TwoFactorsStoreCreateFuncCall {
+	f.mutex.Lock()
+	history := make([]TwoFactorsStoreCreateFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// TwoFactorsStoreCreateFuncCall is an object that describes an invocation
+// of method Create on an instance of MockTwoFactorsStore.
+type TwoFactorsStoreCreateFuncCall struct {
+	// Arg0 is the value of the 1st argument passed to this method
+	// invocation.
+	Arg0 context.Context
+	// Arg1 is the value of the 2nd argument passed to this method
+	// invocation.
+	Arg1 int64
+	// Arg2 is the value of the 3rd argument passed to this method
+	// invocation.
+	Arg2 string
+	// Arg3 is the value of the 4th argument passed to this method
+	// invocation.
+	Arg3 string
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c TwoFactorsStoreCreateFuncCall) Args() []interface{} {
+	return []interface{}{c.Arg0, c.Arg1, c.Arg2, c.Arg3}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c TwoFactorsStoreCreateFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0}
+}
+
+// TwoFactorsStoreGetByUserIDFunc describes the behavior when the
+// GetByUserID method of the parent MockTwoFactorsStore instance is invoked.
+type TwoFactorsStoreGetByUserIDFunc struct {
+	defaultHook func(context.Context, int64) (*TwoFactor, error)
+	hooks       []func(context.Context, int64) (*TwoFactor, error)
+	history     []TwoFactorsStoreGetByUserIDFuncCall
+	mutex       sync.Mutex
+}
+
+// GetByUserID delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockTwoFactorsStore) GetByUserID(v0 context.Context, v1 int64) (*TwoFactor, error) {
+	r0, r1 := m.GetByUserIDFunc.nextHook()(v0, v1)
+	m.GetByUserIDFunc.appendCall(TwoFactorsStoreGetByUserIDFuncCall{v0, v1, r0, r1})
+	return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the GetByUserID method
+// of the parent MockTwoFactorsStore instance is invoked and the hook queue
+// is empty.
+func (f *TwoFactorsStoreGetByUserIDFunc) SetDefaultHook(hook func(context.Context, int64) (*TwoFactor, error)) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// GetByUserID method of the parent MockTwoFactorsStore instance invokes the
+// hook at the front of the queue and discards it. After the queue is empty,
+// the default hook function is invoked for any future action.
+func (f *TwoFactorsStoreGetByUserIDFunc) PushHook(hook func(context.Context, int64) (*TwoFactor, error)) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *TwoFactorsStoreGetByUserIDFunc) SetDefaultReturn(r0 *TwoFactor, r1 error) {
+	f.SetDefaultHook(func(context.Context, int64) (*TwoFactor, error) {
+		return r0, r1
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *TwoFactorsStoreGetByUserIDFunc) PushReturn(r0 *TwoFactor, r1 error) {
+	f.PushHook(func(context.Context, int64) (*TwoFactor, error) {
+		return r0, r1
+	})
+}
+
+func (f *TwoFactorsStoreGetByUserIDFunc) nextHook() func(context.Context, int64) (*TwoFactor, error) {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *TwoFactorsStoreGetByUserIDFunc) appendCall(r0 TwoFactorsStoreGetByUserIDFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of TwoFactorsStoreGetByUserIDFuncCall objects
+// describing the invocations of this function.
+func (f *TwoFactorsStoreGetByUserIDFunc) History() []TwoFactorsStoreGetByUserIDFuncCall {
+	f.mutex.Lock()
+	history := make([]TwoFactorsStoreGetByUserIDFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// TwoFactorsStoreGetByUserIDFuncCall is an object that describes an
+// invocation of method GetByUserID on an instance of MockTwoFactorsStore.
+type TwoFactorsStoreGetByUserIDFuncCall struct {
+	// Arg0 is the value of the 1st argument passed to this method
+	// invocation.
+	Arg0 context.Context
+	// Arg1 is the value of the 2nd argument passed to this method
+	// invocation.
+	Arg1 int64
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 *TwoFactor
+	// Result1 is the value of the 2nd result returned from this method
+	// invocation.
+	Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c TwoFactorsStoreGetByUserIDFuncCall) Args() []interface{} {
+	return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c TwoFactorsStoreGetByUserIDFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0, c.Result1}
+}
+
+// TwoFactorsStoreIsUserEnabledFunc describes the behavior when the
+// IsUserEnabled method of the parent MockTwoFactorsStore instance is
+// invoked.
+type TwoFactorsStoreIsUserEnabledFunc struct {
+	defaultHook func(context.Context, int64) bool
+	hooks       []func(context.Context, int64) bool
+	history     []TwoFactorsStoreIsUserEnabledFuncCall
+	mutex       sync.Mutex
+}
+
+// IsUserEnabled delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockTwoFactorsStore) IsUserEnabled(v0 context.Context, v1 int64) bool {
+	r0 := m.IsUserEnabledFunc.nextHook()(v0, v1)
+	m.IsUserEnabledFunc.appendCall(TwoFactorsStoreIsUserEnabledFuncCall{v0, v1, r0})
+	return r0
+}
+
+// SetDefaultHook sets function that is called when the IsUserEnabled method
+// of the parent MockTwoFactorsStore instance is invoked and the hook queue
+// is empty.
+func (f *TwoFactorsStoreIsUserEnabledFunc) SetDefaultHook(hook func(context.Context, int64) bool) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// IsUserEnabled method of the parent MockTwoFactorsStore instance invokes
+// the hook at the front of the queue and discards it. After the queue is
+// empty, the default hook function is invoked for any future action.
+func (f *TwoFactorsStoreIsUserEnabledFunc) PushHook(hook func(context.Context, int64) bool) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *TwoFactorsStoreIsUserEnabledFunc) SetDefaultReturn(r0 bool) {
+	f.SetDefaultHook(func(context.Context, int64) bool {
+		return r0
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *TwoFactorsStoreIsUserEnabledFunc) PushReturn(r0 bool) {
+	f.PushHook(func(context.Context, int64) bool {
+		return r0
+	})
+}
+
+func (f *TwoFactorsStoreIsUserEnabledFunc) nextHook() func(context.Context, int64) bool {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *TwoFactorsStoreIsUserEnabledFunc) appendCall(r0 TwoFactorsStoreIsUserEnabledFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of TwoFactorsStoreIsUserEnabledFuncCall
+// objects describing the invocations of this function.
+func (f *TwoFactorsStoreIsUserEnabledFunc) History() []TwoFactorsStoreIsUserEnabledFuncCall {
+	f.mutex.Lock()
+	history := make([]TwoFactorsStoreIsUserEnabledFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// TwoFactorsStoreIsUserEnabledFuncCall is an object that describes an
+// invocation of method IsUserEnabled on an instance of MockTwoFactorsStore.
+type TwoFactorsStoreIsUserEnabledFuncCall struct {
+	// Arg0 is the value of the 1st argument passed to this method
+	// invocation.
+	Arg0 context.Context
+	// Arg1 is the value of the 2nd argument passed to this method
+	// invocation.
+	Arg1 int64
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 bool
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c TwoFactorsStoreIsUserEnabledFuncCall) Args() []interface{} {
+	return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c TwoFactorsStoreIsUserEnabledFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0}
+}
+
 // MockUsersStore is a mock implementation of the UsersStore interface (from
 // the package gogs.io/gogs/internal/db) used for unit testing.
 type MockUsersStore struct {

+ 18 - 17
internal/db/two_factors.go

@@ -5,6 +5,7 @@
 package db
 
 import (
+	"context"
 	"encoding/base64"
 	"fmt"
 	"strings"
@@ -23,21 +24,21 @@ import (
 //
 // NOTE: All methods are sorted in alphabetical order.
 type TwoFactorsStore interface {
-	// Create creates a new 2FA token and recovery codes for given user.
-	// The "key" is used to encrypt and later decrypt given "secret",
-	// which should be configured in site-level and change of the "key"
-	// will break all existing 2FA tokens.
-	Create(userID int64, key, secret string) error
-	// GetByUserID returns the 2FA token of given user.
-	// It returns ErrTwoFactorNotFound when not found.
-	GetByUserID(userID int64) (*TwoFactor, error)
+	// Create creates a new 2FA token and recovery codes for given user. The "key"
+	// is used to encrypt and later decrypt given "secret", which should be
+	// configured in site-level and change of the "key" will break all existing 2FA
+	// tokens.
+	Create(ctx context.Context, userID int64, key, secret string) error
+	// GetByUserID returns the 2FA token of given user. It returns
+	// ErrTwoFactorNotFound when not found.
+	GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error)
 	// IsUserEnabled returns true if the user has enabled 2FA.
-	IsUserEnabled(userID int64) bool
+	IsUserEnabled(ctx context.Context, userID int64) bool
 }
 
 var TwoFactors TwoFactorsStore
 
-// NOTE: This is a GORM create hook.
+// BeforeCreate implements the GORM create hook.
 func (t *TwoFactor) BeforeCreate(tx *gorm.DB) error {
 	if t.CreatedUnix == 0 {
 		t.CreatedUnix = tx.NowFunc().Unix()
@@ -45,7 +46,7 @@ func (t *TwoFactor) BeforeCreate(tx *gorm.DB) error {
 	return nil
 }
 
-// NOTE: This is a GORM query hook.
+// AfterFind implements the GORM query hook.
 func (t *TwoFactor) AfterFind(_ *gorm.DB) error {
 	t.Created = time.Unix(t.CreatedUnix, 0).Local()
 	return nil
@@ -57,7 +58,7 @@ type twoFactors struct {
 	*gorm.DB
 }
 
-func (db *twoFactors) Create(userID int64, key, secret string) error {
+func (db *twoFactors) Create(ctx context.Context, userID int64, key, secret string) error {
 	encrypted, err := cryptoutil.AESGCMEncrypt(cryptoutil.MD5Bytes(key), []byte(secret))
 	if err != nil {
 		return errors.Wrap(err, "encrypt secret")
@@ -72,7 +73,7 @@ func (db *twoFactors) Create(userID int64, key, secret string) error {
 		return errors.Wrap(err, "generate recovery codes")
 	}
 
-	return db.Transaction(func(tx *gorm.DB) error {
+	return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
 		err := tx.Create(tf).Error
 		if err != nil {
 			return err
@@ -101,9 +102,9 @@ func (ErrTwoFactorNotFound) NotFound() bool {
 	return true
 }
 
-func (db *twoFactors) GetByUserID(userID int64) (*TwoFactor, error) {
+func (db *twoFactors) GetByUserID(ctx context.Context, userID int64) (*TwoFactor, error) {
 	tf := new(TwoFactor)
-	err := db.Where("user_id = ?", userID).First(tf).Error
+	err := db.WithContext(ctx).Where("user_id = ?", userID).First(tf).Error
 	if err != nil {
 		if err == gorm.ErrRecordNotFound {
 			return nil, ErrTwoFactorNotFound{args: errutil.Args{"userID": userID}}
@@ -113,9 +114,9 @@ func (db *twoFactors) GetByUserID(userID int64) (*TwoFactor, error) {
 	return tf, nil
 }
 
-func (db *twoFactors) IsUserEnabled(userID int64) bool {
+func (db *twoFactors) IsUserEnabled(ctx context.Context, userID int64) bool {
 	var count int64
-	err := db.Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
+	err := db.WithContext(ctx).Model(new(TwoFactor)).Where("user_id = ?", userID).Count(&count).Error
 	if err != nil {
 		log.Error("Failed to count two factors [user_id: %d]: %v", userID, err)
 	}

+ 32 - 38
internal/db/two_factors_test.go

@@ -5,15 +5,17 @@
 package db
 
 import (
+	"context"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"gogs.io/gogs/internal/errutil"
 )
 
-func Test_twoFactors(t *testing.T) {
+func TestTwoFactors(t *testing.T) {
 	if testing.Short() {
 		t.Skip()
 	}
@@ -29,16 +31,14 @@ func Test_twoFactors(t *testing.T) {
 		name string
 		test func(*testing.T, *twoFactors)
 	}{
-		{"Create", test_twoFactors_Create},
-		{"GetByUserID", test_twoFactors_GetByUserID},
-		{"IsUserEnabled", test_twoFactors_IsUserEnabled},
+		{"Create", twoFactorsCreate},
+		{"GetByUserID", twoFactorsGetByUserID},
+		{"IsUserEnabled", twoFactorsIsUserEnabled},
 	} {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Cleanup(func() {
 				err := clearTables(t, db.DB, tables...)
-				if err != nil {
-					t.Fatal(err)
-				}
+				require.NoError(t, err)
 			})
 			tc.test(t, db)
 		})
@@ -48,55 +48,49 @@ func Test_twoFactors(t *testing.T) {
 	}
 }
 
-func test_twoFactors_Create(t *testing.T, db *twoFactors) {
+func twoFactorsCreate(t *testing.T, db *twoFactors) {
+	ctx := context.Background()
+
 	// Create a 2FA token
-	err := db.Create(1, "secure-key", "secure-secret")
-	if err != nil {
-		t.Fatal(err)
-	}
+	err := db.Create(ctx, 1, "secure-key", "secure-secret")
+	require.NoError(t, err)
 
 	// Get it back and check the Created field
-	tf, err := db.GetByUserID(1)
-	if err != nil {
-		t.Fatal(err)
-	}
+	tf, err := db.GetByUserID(ctx, 1)
+	require.NoError(t, err)
 	assert.Equal(t, db.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339))
 
 	// Verify there are 10 recover codes generated
 	var count int64
 	err = db.Model(new(TwoFactorRecoveryCode)).Count(&count).Error
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.NoError(t, err)
 	assert.Equal(t, int64(10), count)
 }
 
-func test_twoFactors_GetByUserID(t *testing.T, db *twoFactors) {
+func twoFactorsGetByUserID(t *testing.T, db *twoFactors) {
+	ctx := context.Background()
+
 	// Create a 2FA token for user 1
-	err := db.Create(1, "secure-key", "secure-secret")
-	if err != nil {
-		t.Fatal(err)
-	}
+	err := db.Create(ctx, 1, "secure-key", "secure-secret")
+	require.NoError(t, err)
 
 	// We should be able to get it back
-	_, err = db.GetByUserID(1)
-	if err != nil {
-		t.Fatal(err)
-	}
+	_, err = db.GetByUserID(ctx, 1)
+	require.NoError(t, err)
 
 	// Try to get a non-existent 2FA token
-	_, err = db.GetByUserID(2)
-	expErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}}
-	assert.Equal(t, expErr, err)
+	_, err = db.GetByUserID(ctx, 2)
+	wantErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}}
+	assert.Equal(t, wantErr, err)
 }
 
-func test_twoFactors_IsUserEnabled(t *testing.T, db *twoFactors) {
+func twoFactorsIsUserEnabled(t *testing.T, db *twoFactors) {
+	ctx := context.Background()
+
 	// Create a 2FA token for user 1
-	err := db.Create(1, "secure-key", "secure-secret")
-	if err != nil {
-		t.Fatal(err)
-	}
+	err := db.Create(ctx, 1, "secure-key", "secure-secret")
+	require.NoError(t, err)
 
-	assert.True(t, db.IsUserEnabled(1))
-	assert.False(t, db.IsUserEnabled(2))
+	assert.True(t, db.IsUserEnabled(ctx, 1))
+	assert.False(t, db.IsUserEnabled(ctx, 2))
 }

+ 1 - 1
internal/db/user.go

@@ -405,7 +405,7 @@ func (u *User) IsPublicMember(orgId int64) bool {
 
 // IsEnabledTwoFactor returns true if user has enabled two-factor authentication.
 func (u *User) IsEnabledTwoFactor() bool {
-	return TwoFactors.IsUserEnabled(u.ID)
+	return TwoFactors.IsUserEnabled(context.TODO(), u.ID)
 }
 
 func (u *User) getOrganizationCount(e Engine) (int64, error) {

+ 12 - 10
internal/route/lfs/route_test.go

@@ -31,7 +31,7 @@ func Test_authenticate(t *testing.T) {
 		name                  string
 		header                http.Header
 		mockUsersStore        func() db.UsersStore
-		mockTwoFactorsStore   *db.MockTwoFactorsStore
+		mockTwoFactorsStore   func() db.TwoFactorsStore
 		mockAccessTokensStore func() db.AccessTokensStore
 		expStatusCode         int
 		expHeader             http.Header
@@ -56,10 +56,10 @@ func Test_authenticate(t *testing.T) {
 				mock.AuthenticateFunc.SetDefaultReturn(&db.User{}, nil)
 				return mock
 			},
-			mockTwoFactorsStore: &db.MockTwoFactorsStore{
-				MockIsUserEnabled: func(userID int64) bool {
-					return true
-				},
+			mockTwoFactorsStore: func() db.TwoFactorsStore {
+				mock := db.NewMockTwoFactorsStore()
+				mock.IsUserEnabledFunc.SetDefaultReturn(true)
+				return mock
 			},
 			expStatusCode: http.StatusBadRequest,
 			expHeader:     http.Header{},
@@ -98,10 +98,10 @@ func Test_authenticate(t *testing.T) {
 				mock.AuthenticateFunc.SetDefaultReturn(&db.User{ID: 1, Name: "unknwon"}, nil)
 				return mock
 			},
-			mockTwoFactorsStore: &db.MockTwoFactorsStore{
-				MockIsUserEnabled: func(userID int64) bool {
-					return false
-				},
+			mockTwoFactorsStore: func() db.TwoFactorsStore {
+				mock := db.NewMockTwoFactorsStore()
+				mock.IsUserEnabledFunc.SetDefaultReturn(false)
+				return mock
 			},
 			expStatusCode: http.StatusOK,
 			expHeader:     http.Header{},
@@ -133,7 +133,9 @@ func Test_authenticate(t *testing.T) {
 			if test.mockUsersStore != nil {
 				db.SetMockUsersStore(t, test.mockUsersStore())
 			}
-			db.SetMockTwoFactorsStore(t, test.mockTwoFactorsStore)
+			if test.mockTwoFactorsStore != nil {
+				db.SetMockTwoFactorsStore(t, test.mockTwoFactorsStore())
+			}
 			if test.mockAccessTokensStore != nil {
 				db.SetMockAccessTokensStore(t, test.mockAccessTokensStore())
 			}

+ 1 - 1
internal/route/user/auth.go

@@ -210,7 +210,7 @@ func LoginTwoFactorPost(c *context.Context) {
 		return
 	}
 
-	t, err := db.TwoFactors.GetByUserID(userID)
+	t, err := db.TwoFactors.GetByUserID(c.Req.Context(), userID)
 	if err != nil {
 		c.Error(err, "get two factor by user ID")
 		return

+ 2 - 2
internal/route/user/setting.go

@@ -381,7 +381,7 @@ func SettingsSecurity(c *context.Context) {
 	c.Title("settings.security")
 	c.PageIs("SettingsSecurity")
 
-	t, err := db.TwoFactors.GetByUserID(c.UserID())
+	t, err := db.TwoFactors.GetByUserID(c.Req.Context(), c.UserID())
 	if err != nil && !db.IsErrTwoFactorNotFound(err) {
 		c.Errorf(err, "get two factor by user ID")
 		return
@@ -449,7 +449,7 @@ func SettingsTwoFactorEnablePost(c *context.Context) {
 		return
 	}
 
-	if err := db.TwoFactors.Create(c.UserID(), conf.Security.SecretKey, secret); err != nil {
+	if err := db.TwoFactors.Create(c.Req.Context(), c.UserID(), conf.Security.SecretKey, secret); err != nil {
 		c.Flash.Error(c.Tr("settings.two_factor_enable_error", err))
 		c.RedirectSubpath("/user/settings/security/two_factor_enable")
 		return