瀏覽代碼

all: unwrap `database.LoginSourcesStore` interface (#7694)

Joe Chen 11 月之前
父節點
當前提交
e634aa6277

+ 3 - 2
internal/auth/auth.go

@@ -23,6 +23,8 @@ const (
 	PAM         // 4
 	DLDAP       // 5
 	GitHub      // 6
+
+	Mock Type = 999
 )
 
 // Name returns the human-readable name for given authentication type.
@@ -45,8 +47,7 @@ type ErrBadCredentials struct {
 // IsErrBadCredentials returns true if the underlying error has the type
 // ErrBadCredentials.
 func IsErrBadCredentials(err error) bool {
-	_, ok := errors.Cause(err).(ErrBadCredentials)
-	return ok
+	return errors.As(err, &ErrBadCredentials{})
 }
 
 func (err ErrBadCredentials) Error() string {

+ 9 - 2
internal/database/database.go

@@ -117,13 +117,12 @@ func NewConnection(w logger.Writer) (*gorm.DB, error) {
 		log.Trace("Auto migrated %q", name)
 	}
 
-	sourceFiles, err := loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
+	loadedLoginSourceFilesStore, err = loadLoginSourceFiles(filepath.Join(conf.CustomDir(), "conf", "auth.d"), db.NowFunc)
 	if err != nil {
 		return nil, errors.Wrap(err, "load login source files")
 	}
 
 	// Initialize stores, sorted in alphabetical order.
-	LoginSources = &loginSourcesStore{DB: db, files: sourceFiles}
 	Notices = NewNoticesStore(db)
 	Orgs = NewOrgsStore(db)
 	Perms = NewPermsStore(db)
@@ -166,3 +165,11 @@ func (db *DB) Actions() *ActionsStore {
 func (db *DB) LFS() *LFSStore {
 	return newLFSStore(db.db)
 }
+
+// NOTE: It is not guarded by a mutex because it only gets written during the
+// service start.
+var loadedLoginSourceFilesStore loginSourceFilesStore
+
+func (db *DB) LoginSources() *LoginSourcesStore {
+	return newLoginSourcesStore(db.db, loadedLoginSourceFilesStore)
+}

+ 1 - 2
internal/database/login_source_files.go

@@ -52,8 +52,7 @@ type ErrLoginSourceNotExist struct {
 }
 
 func IsErrLoginSourceNotExist(err error) bool {
-	_, ok := err.(ErrLoginSourceNotExist)
-	return ok
+	return errors.As(err, &ErrLoginSourceNotExist{})
 }
 
 func (err ErrLoginSourceNotExist) Error() string {

+ 55 - 50
internal/database/login_sources.go

@@ -22,30 +22,6 @@ import (
 	"gogs.io/gogs/internal/errutil"
 )
 
-// LoginSourcesStore is the persistent interface for login sources.
-type LoginSourcesStore interface {
-	// Create creates a new login source and persist to database. It returns
-	// ErrLoginSourceAlreadyExist when a login source with same name already exists.
-	Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error)
-	// Count returns the total number of login sources.
-	Count(ctx context.Context) int64
-	// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
-	// if at least one user is associated with the login source.
-	DeleteByID(ctx context.Context, id int64) error
-	// GetByID returns the login source with given ID. It returns
-	// ErrLoginSourceNotExist when not found.
-	GetByID(ctx context.Context, id int64) (*LoginSource, error)
-	// List returns a list of login sources filtered by options.
-	List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error)
-	// ResetNonDefault clears default flag for all the other login sources.
-	ResetNonDefault(ctx context.Context, source *LoginSource) error
-	// Save persists all values of given login source to database or local file. The
-	// Updated field is set to current time automatically.
-	Save(ctx context.Context, t *LoginSource) error
-}
-
-var LoginSources LoginSourcesStore
-
 // LoginSource represents an external way for authorizing users.
 type LoginSource struct {
 	ID        int64 `gorm:"primaryKey"`
@@ -88,6 +64,10 @@ func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
 	return nil
 }
 
+type mockProviderConfig struct {
+	ExternalAccount *auth.ExternalAccount
+}
+
 // AfterFind implements the GORM query hook.
 func (s *LoginSource) AfterFind(_ *gorm.DB) error {
 	s.Created = time.Unix(s.CreatedUnix, 0).Local()
@@ -134,6 +114,16 @@ func (s *LoginSource) AfterFind(_ *gorm.DB) error {
 		}
 		s.Provider = github.NewProvider(&cfg)
 
+	case auth.Mock:
+		var cfg mockProviderConfig
+		err := jsoniter.UnmarshalFromString(s.Config, &cfg)
+		if err != nil {
+			return err
+		}
+		mockProvider := NewMockProvider()
+		mockProvider.AuthenticateFunc.SetDefaultReturn(cfg.ExternalAccount, nil)
+		s.Provider = mockProvider
+
 	default:
 		return fmt.Errorf("unrecognized login source type: %v", s.Type)
 	}
@@ -180,13 +170,19 @@ func (s *LoginSource) GitHub() *github.Config {
 	return s.Provider.Config().(*github.Config)
 }
 
-var _ LoginSourcesStore = (*loginSourcesStore)(nil)
-
-type loginSourcesStore struct {
-	*gorm.DB
+// LoginSourcesStore is the storage layer for login sources.
+type LoginSourcesStore struct {
+	db    *gorm.DB
 	files loginSourceFilesStore
 }
 
+func newLoginSourcesStore(db *gorm.DB, files loginSourceFilesStore) *LoginSourcesStore {
+	return &LoginSourcesStore{
+		db:    db,
+		files: files,
+	}
+}
+
 type CreateLoginSourceOptions struct {
 	Type      auth.Type
 	Name      string
@@ -200,19 +196,20 @@ type ErrLoginSourceAlreadyExist struct {
 }
 
 func IsErrLoginSourceAlreadyExist(err error) bool {
-	_, ok := err.(ErrLoginSourceAlreadyExist)
-	return ok
+	return errors.As(err, &ErrLoginSourceAlreadyExist{})
 }
 
 func (err ErrLoginSourceAlreadyExist) Error() string {
 	return fmt.Sprintf("login source already exists: %v", err.args)
 }
 
-func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
-	err := s.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
+// Create creates a new login source and persists it to the database. It returns
+// ErrLoginSourceAlreadyExist when a login source with same name already exists.
+func (s *LoginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOptions) (*LoginSource, error) {
+	err := s.db.WithContext(ctx).Where("name = ?", opts.Name).First(new(LoginSource)).Error
 	if err == nil {
 		return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
-	} else if err != gorm.ErrRecordNotFound {
+	} else if !errors.Is(err, gorm.ErrRecordNotFound) {
 		return nil, err
 	}
 
@@ -226,12 +223,13 @@ func (s *loginSourcesStore) Create(ctx context.Context, opts CreateLoginSourceOp
 	if err != nil {
 		return nil, err
 	}
-	return source, s.WithContext(ctx).Create(source).Error
+	return source, s.db.WithContext(ctx).Create(source).Error
 }
 
-func (s *loginSourcesStore) Count(ctx context.Context) int64 {
+// Count returns the total number of login sources.
+func (s *LoginSourcesStore) Count(ctx context.Context) int64 {
 	var count int64
-	s.WithContext(ctx).Model(new(LoginSource)).Count(&count)
+	s.db.WithContext(ctx).Model(new(LoginSource)).Count(&count)
 	return count + int64(s.files.Len())
 }
 
@@ -240,31 +238,34 @@ type ErrLoginSourceInUse struct {
 }
 
 func IsErrLoginSourceInUse(err error) bool {
-	_, ok := err.(ErrLoginSourceInUse)
-	return ok
+	return errors.As(err, &ErrLoginSourceInUse{})
 }
 
 func (err ErrLoginSourceInUse) Error() string {
 	return fmt.Sprintf("login source is still used by some users: %v", err.args)
 }
 
-func (s *loginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
+// DeleteByID deletes a login source by given ID. It returns ErrLoginSourceInUse
+// if at least one user is associated with the login source.
+func (s *LoginSourcesStore) DeleteByID(ctx context.Context, id int64) error {
 	var count int64
-	err := s.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
+	err := s.db.WithContext(ctx).Model(new(User)).Where("login_source = ?", id).Count(&count).Error
 	if err != nil {
 		return err
 	} else if count > 0 {
 		return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
 	}
 
-	return s.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
+	return s.db.WithContext(ctx).Where("id = ?", id).Delete(new(LoginSource)).Error
 }
 
-func (s *loginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
+// GetByID returns the login source with given ID. It returns
+// ErrLoginSourceNotExist when not found.
+func (s *LoginSourcesStore) GetByID(ctx context.Context, id int64) (*LoginSource, error) {
 	source := new(LoginSource)
-	err := s.WithContext(ctx).Where("id = ?", id).First(source).Error
+	err := s.db.WithContext(ctx).Where("id = ?", id).First(source).Error
 	if err != nil {
-		if err == gorm.ErrRecordNotFound {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
 			return s.files.GetByID(id)
 		}
 		return nil, err
@@ -277,9 +278,10 @@ type ListLoginSourceOptions struct {
 	OnlyActivated bool
 }
 
-func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
+// List returns a list of login sources filtered by options.
+func (s *LoginSourcesStore) List(ctx context.Context, opts ListLoginSourceOptions) ([]*LoginSource, error) {
 	var sources []*LoginSource
-	query := s.WithContext(ctx).Order("id ASC")
+	query := s.db.WithContext(ctx).Order("id ASC")
 	if opts.OnlyActivated {
 		query = query.Where("is_actived = ?", true)
 	}
@@ -291,8 +293,9 @@ func (s *loginSourcesStore) List(ctx context.Context, opts ListLoginSourceOption
 	return append(sources, s.files.List(opts)...), nil
 }
 
-func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
-	err := s.WithContext(ctx).
+// ResetNonDefault clears default flag for all the other login sources.
+func (s *LoginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSource) error {
+	err := s.db.WithContext(ctx).
 		Model(new(LoginSource)).
 		Where("id != ?", dflt.ID).
 		Updates(map[string]any{"is_default": false}).
@@ -314,9 +317,11 @@ func (s *loginSourcesStore) ResetNonDefault(ctx context.Context, dflt *LoginSour
 	return nil
 }
 
-func (s *loginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
+// Save persists all values of given login source to database or local file. The
+// Updated field is set to current time automatically.
+func (s *LoginSourcesStore) Save(ctx context.Context, source *LoginSource) error {
 	if source.File == nil {
-		return s.WithContext(ctx).Save(source).Error
+		return s.db.WithContext(ctx).Save(source).Error
 	}
 
 	source.File.SetGeneral("name", source.Name)

+ 57 - 49
internal/database/login_sources_test.go

@@ -163,13 +163,13 @@ func TestLoginSources(t *testing.T) {
 	t.Parallel()
 
 	ctx := context.Background()
-	db := &loginSourcesStore{
-		DB: newTestDB(t, "loginSourcesStore"),
+	s := &LoginSourcesStore{
+		db: newTestDB(t, "LoginSourcesStore"),
 	}
 
 	for _, tc := range []struct {
 		name string
-		test func(t *testing.T, ctx context.Context, db *loginSourcesStore)
+		test func(t *testing.T, ctx context.Context, s *LoginSourcesStore)
 	}{
 		{"Create", loginSourcesCreate},
 		{"Count", loginSourcesCount},
@@ -181,10 +181,10 @@ func TestLoginSources(t *testing.T) {
 	} {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Cleanup(func() {
-				err := clearTables(t, db.DB)
+				err := clearTables(t, s.db)
 				require.NoError(t, err)
 			})
-			tc.test(t, ctx, db)
+			tc.test(t, ctx, s)
 		})
 		if t.Failed() {
 			break
@@ -192,9 +192,9 @@ func TestLoginSources(t *testing.T) {
 	}
 }
 
-func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore) {
+func loginSourcesCreate(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
 	// Create first login source with name "GitHub"
-	source, err := db.Create(ctx,
+	source, err := s.Create(ctx,
 		CreateLoginSourceOptions{
 			Type:      auth.GitHub,
 			Name:      "GitHub",
@@ -208,20 +208,28 @@ func loginSourcesCreate(t *testing.T, ctx context.Context, db *loginSourcesStore
 	require.NoError(t, err)
 
 	// Get it back and check the Created field
-	source, err = db.GetByID(ctx, source.ID)
+	source, err = s.GetByID(ctx, source.ID)
 	require.NoError(t, err)
-	assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
-	assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
+	assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
+	assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
 
-	// Try create second login source with same name should fail
-	_, err = db.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
+	// Try to create second login source with same name should fail.
+	_, err = s.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
 	wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
 	assert.Equal(t, wantErr, err)
 }
 
-func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore) {
+func setMockLoginSourceFilesStore(t *testing.T, s *LoginSourcesStore, mock loginSourceFilesStore) {
+	before := s.files
+	s.files = mock
+	t.Cleanup(func() {
+		s.files = before
+	})
+}
+
+func loginSourcesCount(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
 	// Create two login sources, one in database and one as source file.
-	_, err := db.Create(ctx,
+	_, err := s.Create(ctx,
 		CreateLoginSourceOptions{
 			Type:      auth.GitHub,
 			Name:      "GitHub",
@@ -236,14 +244,14 @@ func loginSourcesCount(t *testing.T, ctx context.Context, db *loginSourcesStore)
 
 	mock := NewMockLoginSourceFilesStore()
 	mock.LenFunc.SetDefaultReturn(2)
-	setMockLoginSourceFilesStore(t, db, mock)
+	setMockLoginSourceFilesStore(t, s, mock)
 
-	assert.Equal(t, int64(3), db.Count(ctx))
+	assert.Equal(t, int64(3), s.Count(ctx))
 }
 
-func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
+func loginSourcesDeleteByID(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
 	t.Run("delete but in used", func(t *testing.T) {
-		source, err := db.Create(ctx,
+		source, err := s.Create(ctx,
 			CreateLoginSourceOptions{
 				Type:      auth.GitHub,
 				Name:      "GitHub",
@@ -257,7 +265,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
 		require.NoError(t, err)
 
 		// Create a user that uses this login source
-		_, err = (&usersStore{DB: db.DB}).Create(ctx, "alice", "",
+		_, err = NewUsersStore(s.db).Create(ctx, "alice", "",
 			CreateUserOptions{
 				LoginSource: source.ID,
 			},
@@ -265,7 +273,7 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
 		require.NoError(t, err)
 
 		// Delete the login source will result in error
-		err = db.DeleteByID(ctx, source.ID)
+		err = s.DeleteByID(ctx, source.ID)
 		wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
 		assert.Equal(t, wantErr, err)
 	})
@@ -274,10 +282,10 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
 	mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
 		return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
 	})
-	setMockLoginSourceFilesStore(t, db, mock)
+	setMockLoginSourceFilesStore(t, s, mock)
 
 	// Create a login source with name "GitHub2"
-	source, err := db.Create(ctx,
+	source, err := s.Create(ctx,
 		CreateLoginSourceOptions{
 			Type:      auth.GitHub,
 			Name:      "GitHub2",
@@ -291,24 +299,24 @@ func loginSourcesDeleteByID(t *testing.T, ctx context.Context, db *loginSourcesS
 	require.NoError(t, err)
 
 	// Delete a non-existent ID is noop
-	err = db.DeleteByID(ctx, 9999)
+	err = s.DeleteByID(ctx, 9999)
 	require.NoError(t, err)
 
 	// We should be able to get it back
-	_, err = db.GetByID(ctx, source.ID)
+	_, err = s.GetByID(ctx, source.ID)
 	require.NoError(t, err)
 
 	// Now delete this login source with ID
-	err = db.DeleteByID(ctx, source.ID)
+	err = s.DeleteByID(ctx, source.ID)
 	require.NoError(t, err)
 
 	// We should get token not found error
-	_, err = db.GetByID(ctx, source.ID)
+	_, err = s.GetByID(ctx, source.ID)
 	wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
 	assert.Equal(t, wantErr, err)
 }
 
-func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStore) {
+func loginSourcesGetByID(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
 	mock := NewMockLoginSourceFilesStore()
 	mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
 		if id != 101 {
@@ -316,14 +324,14 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStor
 		}
 		return &LoginSource{ID: id}, nil
 	})
-	setMockLoginSourceFilesStore(t, db, mock)
+	setMockLoginSourceFilesStore(t, s, mock)
 
 	expConfig := &github.Config{
 		APIEndpoint: "https://api.github.com",
 	}
 
 	// Create a login source with name "GitHub"
-	source, err := db.Create(ctx,
+	source, err := s.Create(ctx,
 		CreateLoginSourceOptions{
 			Type:      auth.GitHub,
 			Name:      "GitHub",
@@ -335,16 +343,16 @@ func loginSourcesGetByID(t *testing.T, ctx context.Context, db *loginSourcesStor
 	require.NoError(t, err)
 
 	// Get the one in the database and test the read/write hooks
-	source, err = db.GetByID(ctx, source.ID)
+	source, err = s.GetByID(ctx, source.ID)
 	require.NoError(t, err)
 	assert.Equal(t, expConfig, source.Provider.Config())
 
 	// Get the one in source file store
-	_, err = db.GetByID(ctx, 101)
+	_, err = s.GetByID(ctx, 101)
 	require.NoError(t, err)
 }
 
-func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore) {
+func loginSourcesList(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
 	mock := NewMockLoginSourceFilesStore()
 	mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
 		if opts.OnlyActivated {
@@ -357,10 +365,10 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
 			{ID: 2},
 		}
 	})
-	setMockLoginSourceFilesStore(t, db, mock)
+	setMockLoginSourceFilesStore(t, s, mock)
 
 	// Create two login sources in database, one activated and the other one not
-	_, err := db.Create(ctx,
+	_, err := s.Create(ctx,
 		CreateLoginSourceOptions{
 			Type: auth.PAM,
 			Name: "PAM",
@@ -370,7 +378,7 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
 		},
 	)
 	require.NoError(t, err)
-	_, err = db.Create(ctx,
+	_, err = s.Create(ctx,
 		CreateLoginSourceOptions{
 			Type:      auth.GitHub,
 			Name:      "GitHub",
@@ -383,17 +391,17 @@ func loginSourcesList(t *testing.T, ctx context.Context, db *loginSourcesStore)
 	require.NoError(t, err)
 
 	// List all login sources
-	sources, err := db.List(ctx, ListLoginSourceOptions{})
+	sources, err := s.List(ctx, ListLoginSourceOptions{})
 	require.NoError(t, err)
 	assert.Equal(t, 4, len(sources), "number of sources")
 
 	// Only list activated login sources
-	sources, err = db.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
+	sources, err = s.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
 	require.NoError(t, err)
 	assert.Equal(t, 2, len(sources), "number of sources")
 }
 
-func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSourcesStore) {
+func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
 	mock := NewMockLoginSourceFilesStore()
 	mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
 		mockFile := NewMockLoginSourceFileStore()
@@ -407,10 +415,10 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
 			},
 		}
 	})
-	setMockLoginSourceFilesStore(t, db, mock)
+	setMockLoginSourceFilesStore(t, s, mock)
 
 	// Create two login sources both have default on
-	source1, err := db.Create(ctx,
+	source1, err := s.Create(ctx,
 		CreateLoginSourceOptions{
 			Type:    auth.PAM,
 			Name:    "PAM",
@@ -421,7 +429,7 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
 		},
 	)
 	require.NoError(t, err)
-	source2, err := db.Create(ctx,
+	source2, err := s.Create(ctx,
 		CreateLoginSourceOptions{
 			Type:      auth.GitHub,
 			Name:      "GitHub",
@@ -435,23 +443,23 @@ func loginSourcesResetNonDefault(t *testing.T, ctx context.Context, db *loginSou
 	require.NoError(t, err)
 
 	// Set source 1 as default
-	err = db.ResetNonDefault(ctx, source1)
+	err = s.ResetNonDefault(ctx, source1)
 	require.NoError(t, err)
 
 	// Verify the default state
-	source1, err = db.GetByID(ctx, source1.ID)
+	source1, err = s.GetByID(ctx, source1.ID)
 	require.NoError(t, err)
 	assert.True(t, source1.IsDefault)
 
-	source2, err = db.GetByID(ctx, source2.ID)
+	source2, err = s.GetByID(ctx, source2.ID)
 	require.NoError(t, err)
 	assert.False(t, source2.IsDefault)
 }
 
-func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore) {
+func loginSourcesSave(t *testing.T, ctx context.Context, s *LoginSourcesStore) {
 	t.Run("save to database", func(t *testing.T) {
 		// Create a login source with name "GitHub"
-		source, err := db.Create(ctx,
+		source, err := s.Create(ctx,
 			CreateLoginSourceOptions{
 				Type:      auth.GitHub,
 				Name:      "GitHub",
@@ -468,10 +476,10 @@ func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore)
 		source.Provider = github.NewProvider(&github.Config{
 			APIEndpoint: "https://api2.github.com",
 		})
-		err = db.Save(ctx, source)
+		err = s.Save(ctx, source)
 		require.NoError(t, err)
 
-		source, err = db.GetByID(ctx, source.ID)
+		source, err = s.GetByID(ctx, source.ID)
 		require.NoError(t, err)
 		assert.False(t, source.IsActived)
 		assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
@@ -485,7 +493,7 @@ func loginSourcesSave(t *testing.T, ctx context.Context, db *loginSourcesStore)
 			}),
 			File: mockFile,
 		}
-		err := db.Save(ctx, source)
+		err := s.Save(ctx, source)
 		require.NoError(t, err)
 		mockrequire.Called(t, mockFile.SaveFunc)
 	})

+ 0 - 16
internal/database/mocks.go

@@ -8,22 +8,6 @@ import (
 	"testing"
 )
 
-func setMockLoginSourcesStore(t *testing.T, mock LoginSourcesStore) {
-	before := LoginSources
-	LoginSources = mock
-	t.Cleanup(func() {
-		LoginSources = before
-	})
-}
-
-func setMockLoginSourceFilesStore(t *testing.T, db *loginSourcesStore, mock loginSourceFilesStore) {
-	before := db.files
-	db.files = mock
-	t.Cleanup(func() {
-		db.files = before
-	})
-}
-
 func SetMockPermsStore(t *testing.T, mock PermsStore) {
 	before := Perms
 	Perms = mock

+ 620 - 0
internal/database/mocks_gen.go

@@ -0,0 +1,620 @@
+// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
+//
+// This file was generated by running `go-mockgen` at the root of this repository.
+// To add additional mocks to this or another package, add a new entry to the
+// mockgen.yaml file in the root of this repository.
+
+package database
+
+import (
+	"sync"
+
+	auth "gogs.io/gogs/internal/auth"
+)
+
+// MockProvider is a mock implementation of the Provider interface (from the
+// package gogs.io/gogs/internal/auth) used for unit testing.
+type MockProvider struct {
+	// AuthenticateFunc is an instance of a mock function object controlling
+	// the behavior of the method Authenticate.
+	AuthenticateFunc *ProviderAuthenticateFunc
+	// ConfigFunc is an instance of a mock function object controlling the
+	// behavior of the method Config.
+	ConfigFunc *ProviderConfigFunc
+	// HasTLSFunc is an instance of a mock function object controlling the
+	// behavior of the method HasTLS.
+	HasTLSFunc *ProviderHasTLSFunc
+	// SkipTLSVerifyFunc is an instance of a mock function object
+	// controlling the behavior of the method SkipTLSVerify.
+	SkipTLSVerifyFunc *ProviderSkipTLSVerifyFunc
+	// UseTLSFunc is an instance of a mock function object controlling the
+	// behavior of the method UseTLS.
+	UseTLSFunc *ProviderUseTLSFunc
+}
+
+// NewMockProvider creates a new mock of the Provider interface. All methods
+// return zero values for all results, unless overwritten.
+func NewMockProvider() *MockProvider {
+	return &MockProvider{
+		AuthenticateFunc: &ProviderAuthenticateFunc{
+			defaultHook: func(string, string) (r0 *auth.ExternalAccount, r1 error) {
+				return
+			},
+		},
+		ConfigFunc: &ProviderConfigFunc{
+			defaultHook: func() (r0 interface{}) {
+				return
+			},
+		},
+		HasTLSFunc: &ProviderHasTLSFunc{
+			defaultHook: func() (r0 bool) {
+				return
+			},
+		},
+		SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
+			defaultHook: func() (r0 bool) {
+				return
+			},
+		},
+		UseTLSFunc: &ProviderUseTLSFunc{
+			defaultHook: func() (r0 bool) {
+				return
+			},
+		},
+	}
+}
+
+// NewStrictMockProvider creates a new mock of the Provider interface. All
+// methods panic on invocation, unless overwritten.
+func NewStrictMockProvider() *MockProvider {
+	return &MockProvider{
+		AuthenticateFunc: &ProviderAuthenticateFunc{
+			defaultHook: func(string, string) (*auth.ExternalAccount, error) {
+				panic("unexpected invocation of MockProvider.Authenticate")
+			},
+		},
+		ConfigFunc: &ProviderConfigFunc{
+			defaultHook: func() interface{} {
+				panic("unexpected invocation of MockProvider.Config")
+			},
+		},
+		HasTLSFunc: &ProviderHasTLSFunc{
+			defaultHook: func() bool {
+				panic("unexpected invocation of MockProvider.HasTLS")
+			},
+		},
+		SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
+			defaultHook: func() bool {
+				panic("unexpected invocation of MockProvider.SkipTLSVerify")
+			},
+		},
+		UseTLSFunc: &ProviderUseTLSFunc{
+			defaultHook: func() bool {
+				panic("unexpected invocation of MockProvider.UseTLS")
+			},
+		},
+	}
+}
+
+// NewMockProviderFrom creates a new mock of the MockProvider interface. All
+// methods delegate to the given implementation, unless overwritten.
+func NewMockProviderFrom(i auth.Provider) *MockProvider {
+	return &MockProvider{
+		AuthenticateFunc: &ProviderAuthenticateFunc{
+			defaultHook: i.Authenticate,
+		},
+		ConfigFunc: &ProviderConfigFunc{
+			defaultHook: i.Config,
+		},
+		HasTLSFunc: &ProviderHasTLSFunc{
+			defaultHook: i.HasTLS,
+		},
+		SkipTLSVerifyFunc: &ProviderSkipTLSVerifyFunc{
+			defaultHook: i.SkipTLSVerify,
+		},
+		UseTLSFunc: &ProviderUseTLSFunc{
+			defaultHook: i.UseTLS,
+		},
+	}
+}
+
+// ProviderAuthenticateFunc describes the behavior when the Authenticate
+// method of the parent MockProvider instance is invoked.
+type ProviderAuthenticateFunc struct {
+	defaultHook func(string, string) (*auth.ExternalAccount, error)
+	hooks       []func(string, string) (*auth.ExternalAccount, error)
+	history     []ProviderAuthenticateFuncCall
+	mutex       sync.Mutex
+}
+
+// Authenticate delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockProvider) Authenticate(v0 string, v1 string) (*auth.ExternalAccount, error) {
+	r0, r1 := m.AuthenticateFunc.nextHook()(v0, v1)
+	m.AuthenticateFunc.appendCall(ProviderAuthenticateFuncCall{v0, v1, r0, r1})
+	return r0, r1
+}
+
+// SetDefaultHook sets function that is called when the Authenticate method
+// of the parent MockProvider instance is invoked and the hook queue is
+// empty.
+func (f *ProviderAuthenticateFunc) SetDefaultHook(hook func(string, string) (*auth.ExternalAccount, error)) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// Authenticate method of the parent MockProvider instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *ProviderAuthenticateFunc) PushHook(hook func(string, string) (*auth.ExternalAccount, error)) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *ProviderAuthenticateFunc) SetDefaultReturn(r0 *auth.ExternalAccount, r1 error) {
+	f.SetDefaultHook(func(string, string) (*auth.ExternalAccount, error) {
+		return r0, r1
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *ProviderAuthenticateFunc) PushReturn(r0 *auth.ExternalAccount, r1 error) {
+	f.PushHook(func(string, string) (*auth.ExternalAccount, error) {
+		return r0, r1
+	})
+}
+
+func (f *ProviderAuthenticateFunc) nextHook() func(string, string) (*auth.ExternalAccount, error) {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *ProviderAuthenticateFunc) appendCall(r0 ProviderAuthenticateFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of ProviderAuthenticateFuncCall objects
+// describing the invocations of this function.
+func (f *ProviderAuthenticateFunc) History() []ProviderAuthenticateFuncCall {
+	f.mutex.Lock()
+	history := make([]ProviderAuthenticateFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// ProviderAuthenticateFuncCall is an object that describes an invocation of
+// method Authenticate on an instance of MockProvider.
+type ProviderAuthenticateFuncCall struct {
+	// Arg0 is the value of the 1st argument passed to this method
+	// invocation.
+	Arg0 string
+	// Arg1 is the value of the 2nd argument passed to this method
+	// invocation.
+	Arg1 string
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 *auth.ExternalAccount
+	// Result1 is the value of the 2nd result returned from this method
+	// invocation.
+	Result1 error
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c ProviderAuthenticateFuncCall) Args() []interface{} {
+	return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c ProviderAuthenticateFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0, c.Result1}
+}
+
+// ProviderConfigFunc describes the behavior when the Config method of the
+// parent MockProvider instance is invoked.
+type ProviderConfigFunc struct {
+	defaultHook func() interface{}
+	hooks       []func() interface{}
+	history     []ProviderConfigFuncCall
+	mutex       sync.Mutex
+}
+
+// Config delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockProvider) Config() interface{} {
+	r0 := m.ConfigFunc.nextHook()()
+	m.ConfigFunc.appendCall(ProviderConfigFuncCall{r0})
+	return r0
+}
+
+// SetDefaultHook sets function that is called when the Config method of the
+// parent MockProvider instance is invoked and the hook queue is empty.
+func (f *ProviderConfigFunc) SetDefaultHook(hook func() interface{}) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// Config method of the parent MockProvider instance invokes the hook at the
+// front of the queue and discards it. After the queue is empty, the default
+// hook function is invoked for any future action.
+func (f *ProviderConfigFunc) PushHook(hook func() interface{}) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *ProviderConfigFunc) SetDefaultReturn(r0 interface{}) {
+	f.SetDefaultHook(func() interface{} {
+		return r0
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *ProviderConfigFunc) PushReturn(r0 interface{}) {
+	f.PushHook(func() interface{} {
+		return r0
+	})
+}
+
+func (f *ProviderConfigFunc) nextHook() func() interface{} {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *ProviderConfigFunc) appendCall(r0 ProviderConfigFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of ProviderConfigFuncCall objects describing
+// the invocations of this function.
+func (f *ProviderConfigFunc) History() []ProviderConfigFuncCall {
+	f.mutex.Lock()
+	history := make([]ProviderConfigFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// ProviderConfigFuncCall is an object that describes an invocation of
+// method Config on an instance of MockProvider.
+type ProviderConfigFuncCall struct {
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 interface{}
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c ProviderConfigFuncCall) Args() []interface{} {
+	return []interface{}{}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c ProviderConfigFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0}
+}
+
+// ProviderHasTLSFunc describes the behavior when the HasTLS method of the
+// parent MockProvider instance is invoked.
+type ProviderHasTLSFunc struct {
+	defaultHook func() bool
+	hooks       []func() bool
+	history     []ProviderHasTLSFuncCall
+	mutex       sync.Mutex
+}
+
+// HasTLS delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockProvider) HasTLS() bool {
+	r0 := m.HasTLSFunc.nextHook()()
+	m.HasTLSFunc.appendCall(ProviderHasTLSFuncCall{r0})
+	return r0
+}
+
+// SetDefaultHook sets function that is called when the HasTLS method of the
+// parent MockProvider instance is invoked and the hook queue is empty.
+func (f *ProviderHasTLSFunc) SetDefaultHook(hook func() bool) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// HasTLS method of the parent MockProvider instance invokes the hook at the
+// front of the queue and discards it. After the queue is empty, the default
+// hook function is invoked for any future action.
+func (f *ProviderHasTLSFunc) PushHook(hook func() bool) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *ProviderHasTLSFunc) SetDefaultReturn(r0 bool) {
+	f.SetDefaultHook(func() bool {
+		return r0
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *ProviderHasTLSFunc) PushReturn(r0 bool) {
+	f.PushHook(func() bool {
+		return r0
+	})
+}
+
+func (f *ProviderHasTLSFunc) nextHook() func() bool {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *ProviderHasTLSFunc) appendCall(r0 ProviderHasTLSFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of ProviderHasTLSFuncCall objects describing
+// the invocations of this function.
+func (f *ProviderHasTLSFunc) History() []ProviderHasTLSFuncCall {
+	f.mutex.Lock()
+	history := make([]ProviderHasTLSFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// ProviderHasTLSFuncCall is an object that describes an invocation of
+// method HasTLS on an instance of MockProvider.
+type ProviderHasTLSFuncCall struct {
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 bool
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c ProviderHasTLSFuncCall) Args() []interface{} {
+	return []interface{}{}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c ProviderHasTLSFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0}
+}
+
+// ProviderSkipTLSVerifyFunc describes the behavior when the SkipTLSVerify
+// method of the parent MockProvider instance is invoked.
+type ProviderSkipTLSVerifyFunc struct {
+	defaultHook func() bool
+	hooks       []func() bool
+	history     []ProviderSkipTLSVerifyFuncCall
+	mutex       sync.Mutex
+}
+
+// SkipTLSVerify delegates to the next hook function in the queue and stores
+// the parameter and result values of this invocation.
+func (m *MockProvider) SkipTLSVerify() bool {
+	r0 := m.SkipTLSVerifyFunc.nextHook()()
+	m.SkipTLSVerifyFunc.appendCall(ProviderSkipTLSVerifyFuncCall{r0})
+	return r0
+}
+
+// SetDefaultHook sets function that is called when the SkipTLSVerify method
+// of the parent MockProvider instance is invoked and the hook queue is
+// empty.
+func (f *ProviderSkipTLSVerifyFunc) SetDefaultHook(hook func() bool) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// SkipTLSVerify method of the parent MockProvider instance invokes the hook
+// at the front of the queue and discards it. After the queue is empty, the
+// default hook function is invoked for any future action.
+func (f *ProviderSkipTLSVerifyFunc) PushHook(hook func() bool) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *ProviderSkipTLSVerifyFunc) SetDefaultReturn(r0 bool) {
+	f.SetDefaultHook(func() bool {
+		return r0
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *ProviderSkipTLSVerifyFunc) PushReturn(r0 bool) {
+	f.PushHook(func() bool {
+		return r0
+	})
+}
+
+func (f *ProviderSkipTLSVerifyFunc) nextHook() func() bool {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *ProviderSkipTLSVerifyFunc) appendCall(r0 ProviderSkipTLSVerifyFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of ProviderSkipTLSVerifyFuncCall objects
+// describing the invocations of this function.
+func (f *ProviderSkipTLSVerifyFunc) History() []ProviderSkipTLSVerifyFuncCall {
+	f.mutex.Lock()
+	history := make([]ProviderSkipTLSVerifyFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// ProviderSkipTLSVerifyFuncCall is an object that describes an invocation
+// of method SkipTLSVerify on an instance of MockProvider.
+type ProviderSkipTLSVerifyFuncCall struct {
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 bool
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c ProviderSkipTLSVerifyFuncCall) Args() []interface{} {
+	return []interface{}{}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c ProviderSkipTLSVerifyFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0}
+}
+
+// ProviderUseTLSFunc describes the behavior when the UseTLS method of the
+// parent MockProvider instance is invoked.
+type ProviderUseTLSFunc struct {
+	defaultHook func() bool
+	hooks       []func() bool
+	history     []ProviderUseTLSFuncCall
+	mutex       sync.Mutex
+}
+
+// UseTLS delegates to the next hook function in the queue and stores the
+// parameter and result values of this invocation.
+func (m *MockProvider) UseTLS() bool {
+	r0 := m.UseTLSFunc.nextHook()()
+	m.UseTLSFunc.appendCall(ProviderUseTLSFuncCall{r0})
+	return r0
+}
+
+// SetDefaultHook sets function that is called when the UseTLS method of the
+// parent MockProvider instance is invoked and the hook queue is empty.
+func (f *ProviderUseTLSFunc) SetDefaultHook(hook func() bool) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// UseTLS method of the parent MockProvider instance invokes the hook at the
+// front of the queue and discards it. After the queue is empty, the default
+// hook function is invoked for any future action.
+func (f *ProviderUseTLSFunc) PushHook(hook func() bool) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *ProviderUseTLSFunc) SetDefaultReturn(r0 bool) {
+	f.SetDefaultHook(func() bool {
+		return r0
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *ProviderUseTLSFunc) PushReturn(r0 bool) {
+	f.PushHook(func() bool {
+		return r0
+	})
+}
+
+func (f *ProviderUseTLSFunc) nextHook() func() bool {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *ProviderUseTLSFunc) appendCall(r0 ProviderUseTLSFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of ProviderUseTLSFuncCall objects describing
+// the invocations of this function.
+func (f *ProviderUseTLSFunc) History() []ProviderUseTLSFuncCall {
+	f.mutex.Lock()
+	history := make([]ProviderUseTLSFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// ProviderUseTLSFuncCall is an object that describes an invocation of
+// method UseTLS on an instance of MockProvider.
+type ProviderUseTLSFuncCall struct {
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 bool
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c ProviderUseTLSFuncCall) Args() []interface{} {
+	return []interface{}{}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c ProviderUseTLSFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0}
+}

File diff suppressed because it is too large
+ 255 - 1658
internal/database/mocks_test.go


+ 1 - 1
internal/database/models.go

@@ -224,7 +224,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
 	stats.Counter.Follow, _ = x.Count(new(Follow))
 	stats.Counter.Mirror, _ = x.Count(new(Mirror))
 	stats.Counter.Release, _ = x.Count(new(Release))
-	stats.Counter.LoginSource = LoginSources.Count(ctx)
+	stats.Counter.LoginSource = Handle.LoginSources().Count(ctx)
 	stats.Counter.Webhook, _ = x.Count(new(Webhook))
 	stats.Counter.Milestone, _ = x.Count(new(Milestone))
 	stats.Counter.Label, _ = x.Count(new(Label))

+ 2 - 2
internal/database/users.go

@@ -185,7 +185,7 @@ func (s *usersStore) Authenticate(ctx context.Context, login, password string, l
 
 	user := new(User)
 	err := query.First(user).Error
-	if err != nil && err != gorm.ErrRecordNotFound {
+	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
 		return nil, errors.Wrap(err, "get user")
 	}
 
@@ -221,7 +221,7 @@ func (s *usersStore) Authenticate(ctx context.Context, login, password string, l
 		createNewUser = true
 	}
 
-	source, err := LoginSources.GetByID(ctx, authSourceID)
+	source, err := newLoginSourcesStore(s.DB, loadedLoginSourceFilesStore).GetByID(ctx, authSourceID)
 	if err != nil {
 		return nil, errors.Wrap(err, "get login source")
 	}

+ 30 - 29
internal/database/users_test.go

@@ -175,17 +175,19 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
 	})
 
 	t.Run("via login source", func(t *testing.T) {
-		mockLoginSources := NewMockLoginSourcesStore()
-		mockLoginSources.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int64) (*LoginSource, error) {
-			mockProvider := NewMockProvider()
-			mockProvider.AuthenticateFunc.SetDefaultReturn(&auth.ExternalAccount{}, nil)
-			s := &LoginSource{
-				IsActived: true,
-				Provider:  mockProvider,
-			}
-			return s, nil
-		})
-		setMockLoginSourcesStore(t, mockLoginSources)
+		loginSourcesStore := newLoginSourcesStore(db.DB, NewMockLoginSourceFilesStore())
+		loginSource, err := loginSourcesStore.Create(
+			ctx,
+			CreateLoginSourceOptions{
+				Type:      auth.Mock,
+				Name:      "mock-1",
+				Activated: true,
+				Config: mockProviderConfig{
+					ExternalAccount: &auth.ExternalAccount{},
+				},
+			},
+		)
+		require.NoError(t, err)
 
 		bob, err := db.Create(ctx, "bob", "[email protected]",
 			CreateUserOptions{
@@ -195,31 +197,30 @@ func usersAuthenticate(t *testing.T, ctx context.Context, db *usersStore) {
 		)
 		require.NoError(t, err)
 
-		user, err := db.Authenticate(ctx, bob.Email, password, 1)
+		user, err := db.Authenticate(ctx, bob.Email, password, loginSource.ID)
 		require.NoError(t, err)
 		assert.Equal(t, bob.Name, user.Name)
 	})
 
 	t.Run("new user via login source", func(t *testing.T) {
-		mockLoginSources := NewMockLoginSourcesStore()
-		mockLoginSources.GetByIDFunc.SetDefaultHook(func(ctx context.Context, id int64) (*LoginSource, error) {
-			mockProvider := NewMockProvider()
-			mockProvider.AuthenticateFunc.SetDefaultReturn(
-				&auth.ExternalAccount{
-					Name:  "cindy",
-					Email: "[email protected]",
+		loginSourcesStore := newLoginSourcesStore(db.DB, NewMockLoginSourceFilesStore())
+		loginSource, err := loginSourcesStore.Create(
+			ctx,
+			CreateLoginSourceOptions{
+				Type:      auth.Mock,
+				Name:      "mock-2",
+				Activated: true,
+				Config: mockProviderConfig{
+					ExternalAccount: &auth.ExternalAccount{
+						Name:  "cindy",
+						Email: "[email protected]",
+					},
 				},
-				nil,
-			)
-			s := &LoginSource{
-				IsActived: true,
-				Provider:  mockProvider,
-			}
-			return s, nil
-		})
-		setMockLoginSourcesStore(t, mockLoginSources)
+			},
+		)
+		require.NoError(t, err)
 
-		user, err := db.Authenticate(ctx, "cindy", password, 1)
+		user, err := db.Authenticate(ctx, "cindy", password, loginSource.ID)
 		require.NoError(t, err)
 		assert.Equal(t, "cindy", user.Name)
 

+ 9 - 9
internal/route/admin/auths.go

@@ -35,13 +35,13 @@ func Authentications(c *context.Context) {
 	c.PageIs("AdminAuthentications")
 
 	var err error
-	c.Data["Sources"], err = database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
+	c.Data["Sources"], err = database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
 	if err != nil {
 		c.Error(err, "list login sources")
 		return
 	}
 
-	c.Data["Total"] = database.LoginSources.Count(c.Req.Context())
+	c.Data["Total"] = database.Handle.LoginSources().Count(c.Req.Context())
 	c.Success(AUTHS)
 }
 
@@ -159,7 +159,7 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
 		return
 	}
 
-	source, err := database.LoginSources.Create(c.Req.Context(),
+	source, err := database.Handle.LoginSources().Create(c.Req.Context(),
 		database.CreateLoginSourceOptions{
 			Type:      auth.Type(f.Type),
 			Name:      f.Name,
@@ -179,7 +179,7 @@ func NewAuthSourcePost(c *context.Context, f form.Authentication) {
 	}
 
 	if source.IsDefault {
-		err = database.LoginSources.ResetNonDefault(c.Req.Context(), source)
+		err = database.Handle.LoginSources().ResetNonDefault(c.Req.Context(), source)
 		if err != nil {
 			c.Error(err, "reset non-default login sources")
 			return
@@ -200,7 +200,7 @@ func EditAuthSource(c *context.Context) {
 	c.Data["SecurityProtocols"] = securityProtocols
 	c.Data["SMTPAuths"] = smtp.AuthTypes
 
-	source, err := database.LoginSources.GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
+	source, err := database.Handle.LoginSources().GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
 	if err != nil {
 		c.Error(err, "get login source by ID")
 		return
@@ -218,7 +218,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
 
 	c.Data["SMTPAuths"] = smtp.AuthTypes
 
-	source, err := database.LoginSources.GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
+	source, err := database.Handle.LoginSources().GetByID(c.Req.Context(), c.ParamsInt64(":authid"))
 	if err != nil {
 		c.Error(err, "get login source by ID")
 		return
@@ -257,13 +257,13 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
 	source.IsActived = f.IsActive
 	source.IsDefault = f.IsDefault
 	source.Provider = provider
-	if err := database.LoginSources.Save(c.Req.Context(), source); err != nil {
+	if err := database.Handle.LoginSources().Save(c.Req.Context(), source); err != nil {
 		c.Error(err, "update login source")
 		return
 	}
 
 	if source.IsDefault {
-		err = database.LoginSources.ResetNonDefault(c.Req.Context(), source)
+		err = database.Handle.LoginSources().ResetNonDefault(c.Req.Context(), source)
 		if err != nil {
 			c.Error(err, "reset non-default login sources")
 			return
@@ -278,7 +278,7 @@ func EditAuthSourcePost(c *context.Context, f form.Authentication) {
 
 func DeleteAuthSource(c *context.Context) {
 	id := c.ParamsInt64(":authid")
-	if err := database.LoginSources.DeleteByID(c.Req.Context(), id); err != nil {
+	if err := database.Handle.LoginSources().DeleteByID(c.Req.Context(), id); err != nil {
 		if database.IsErrLoginSourceInUse(err) {
 			c.Flash.Error(c.Tr("admin.auths.still_in_used"))
 		} else {

+ 4 - 4
internal/route/admin/users.go

@@ -46,7 +46,7 @@ func NewUser(c *context.Context) {
 
 	c.Data["login_type"] = "0-0"
 
-	sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
+	sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
 	if err != nil {
 		c.Error(err, "list login sources")
 		return
@@ -62,7 +62,7 @@ func NewUserPost(c *context.Context, f form.AdminCrateUser) {
 	c.Data["PageIsAdmin"] = true
 	c.Data["PageIsAdminUsers"] = true
 
-	sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
+	sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
 	if err != nil {
 		c.Error(err, "list login sources")
 		return
@@ -125,7 +125,7 @@ func prepareUserInfo(c *context.Context) *database.User {
 	c.Data["User"] = u
 
 	if u.LoginSource > 0 {
-		c.Data["LoginSource"], err = database.LoginSources.GetByID(c.Req.Context(), u.LoginSource)
+		c.Data["LoginSource"], err = database.Handle.LoginSources().GetByID(c.Req.Context(), u.LoginSource)
 		if err != nil {
 			c.Error(err, "get login source by ID")
 			return nil
@@ -134,7 +134,7 @@ func prepareUserInfo(c *context.Context) *database.User {
 		c.Data["LoginSource"] = &database.LoginSource{}
 	}
 
-	sources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{})
+	sources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{})
 	if err != nil {
 		c.Error(err, "list login sources")
 		return nil

+ 1 - 1
internal/route/api/v1/admin/user.go

@@ -22,7 +22,7 @@ func parseLoginSource(c *context.APIContext, sourceID int64) {
 		return
 	}
 
-	_, err := database.LoginSources.GetByID(c.Req.Context(), sourceID)
+	_, err := database.Handle.LoginSources().GetByID(c.Req.Context(), sourceID)
 	if err != nil {
 		if database.IsErrLoginSourceNotExist(err) {
 			c.ErrorStatus(http.StatusUnprocessableEntity, err)

+ 2 - 2
internal/route/user/auth.go

@@ -106,7 +106,7 @@ func Login(c *context.Context) {
 	}
 
 	// Display normal login page
-	loginSources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
+	loginSources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
 	if err != nil {
 		c.Error(err, "list activated login sources")
 		return
@@ -153,7 +153,7 @@ func afterLogin(c *context.Context, u *database.User, remember bool) {
 func LoginPost(c *context.Context, f form.SignIn) {
 	c.Title("sign_in")
 
-	loginSources, err := database.LoginSources.List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
+	loginSources, err := database.Handle.LoginSources().List(c.Req.Context(), database.ListLoginSourceOptions{OnlyActivated: true})
 	if err != nil {
 		c.Error(err, "list activated login sources")
 		return

+ 3 - 3
mockgen.yaml

@@ -25,10 +25,10 @@ mocks:
     sources:
       - path: gogs.io/gogs/internal/database
         interfaces:
-          - LoginSourcesStore
-          - LoginSourceFilesStore
-          - LoginSourceFileStore
           - loginSourceFileStore
+          - loginSourceFilesStore
+  - filename: internal/database/mocks_gen.go
+    sources:
       - path: gogs.io/gogs/internal/auth
         interfaces:
           - Provider

Some files were not shown because too many files changed in this diff