Procházet zdrojové kódy

refactor(db): migrate methods off `user.go` (#7228)

Joe Chen před 2 roky
rodič
revize
fd798b4197

+ 9 - 7
internal/db/org.go

@@ -5,6 +5,7 @@
 package db
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"os"
@@ -104,18 +105,19 @@ func CreateOrganization(org, owner *User) (err error) {
 		return err
 	}
 
-	isExist, err := IsUserExist(0, org.Name)
-	if err != nil {
-		return err
-	} else if isExist {
-		return ErrUserAlreadyExist{args: errutil.Args{"name": org.Name}}
+	if Users.IsUsernameUsed(context.TODO(), org.Name) {
+		return ErrUserAlreadyExist{
+			args: errutil.Args{
+				"name": org.Name,
+			},
+		}
 	}
 
 	org.LowerName = strings.ToLower(org.Name)
-	if org.Rands, err = GetUserSalt(); err != nil {
+	if org.Rands, err = userutil.RandomSalt(); err != nil {
 		return err
 	}
-	if org.Salt, err = GetUserSalt(); err != nil {
+	if org.Salt, err = userutil.RandomSalt(); err != nil {
 		return err
 	}
 	org.UseCustomAvatar = true

+ 7 - 38
internal/db/user.go

@@ -24,7 +24,6 @@ import (
 	"gogs.io/gogs/internal/conf"
 	"gogs.io/gogs/internal/db/errors"
 	"gogs.io/gogs/internal/errutil"
-	"gogs.io/gogs/internal/strutil"
 	"gogs.io/gogs/internal/tool"
 	"gogs.io/gogs/internal/userutil"
 )
@@ -55,36 +54,11 @@ func (u *User) AfterSet(colName string, _ xorm.Cell) {
 
 // Deprecated: Use OrgsUsers.CountByUser instead.
 //
-// TODO(unknwon): Delete me once no more call sites.
+// TODO(unknwon): Delete me once no more call sites in this file.
 func (u *User) getOrganizationCount(e Engine) (int64, error) {
 	return e.Where("uid=?", u.ID).Count(new(OrgUser))
 }
 
-// IsUserExist checks if given user name exist,
-// the user name should be noncased unique.
-// If uid is presented, then check will rule out that one,
-// it is used when update a user name in settings page.
-func IsUserExist(uid int64, name string) (bool, error) {
-	if name == "" {
-		return false, nil
-	}
-	return x.Where("id != ?", uid).Get(&User{LowerName: strings.ToLower(name)})
-}
-
-// GetUserSalt returns a random user salt token.
-func GetUserSalt() (string, error) {
-	return strutil.RandomChars(10)
-}
-
-// NewGhostUser creates and returns a fake user for someone who has deleted his/her account.
-func NewGhostUser() *User {
-	return &User{
-		ID:        -1,
-		Name:      "Ghost",
-		LowerName: "ghost",
-	}
-}
-
 var (
 	reservedUsernames    = []string{"-", "explore", "create", "assets", "css", "img", "js", "less", "plugins", "debug", "raw", "install", "api", "avatar", "user", "org", "help", "stars", "issues", "pulls", "commits", "repo", "template", "admin", "new", ".", ".."}
 	reservedUserPatterns = []string{"*.keys"}
@@ -148,21 +122,19 @@ func isUsernameAllowed(name string) error {
 }
 
 // CreateUser creates record of a new user.
+//
 // Deprecated: Use Users.Create instead.
 func CreateUser(u *User) (err error) {
 	if err = isUsernameAllowed(u.Name); err != nil {
 		return err
 	}
 
-	isExist, err := IsUserExist(0, u.Name)
-	if err != nil {
-		return err
-	} else if isExist {
+	if Users.IsUsernameUsed(context.TODO(), u.Name) {
 		return ErrUserAlreadyExist{args: errutil.Args{"name": u.Name}}
 	}
 
 	u.Email = strings.ToLower(u.Email)
-	isExist, err = IsEmailUsed(u.Email)
+	isExist, err := IsEmailUsed(u.Email)
 	if err != nil {
 		return err
 	} else if isExist {
@@ -172,10 +144,10 @@ func CreateUser(u *User) (err error) {
 	u.LowerName = strings.ToLower(u.Name)
 	u.AvatarEmail = u.Email
 	u.Avatar = tool.HashEmail(u.AvatarEmail)
-	if u.Rands, err = GetUserSalt(); err != nil {
+	if u.Rands, err = userutil.RandomSalt(); err != nil {
 		return err
 	}
-	if u.Salt, err = GetUserSalt(); err != nil {
+	if u.Salt, err = userutil.RandomSalt(); err != nil {
 		return err
 	}
 	u.Password = userutil.EncodePassword(u.Password, u.Salt)
@@ -273,10 +245,7 @@ func ChangeUserName(u *User, newUserName string) (err error) {
 		return err
 	}
 
-	isExist, err := IsUserExist(0, newUserName)
-	if err != nil {
-		return err
-	} else if isExist {
+	if Users.IsUsernameUsed(context.TODO(), newUserName) {
 		return ErrUserAlreadyExist{args: errutil.Args{"name": newUserName}}
 	}
 

+ 2 - 1
internal/db/user_mail.go

@@ -10,6 +10,7 @@ import (
 
 	"gogs.io/gogs/internal/db/errors"
 	"gogs.io/gogs/internal/errutil"
+	"gogs.io/gogs/internal/userutil"
 )
 
 // EmailAddresses is the list of all email addresses of a user. Can contain the
@@ -122,7 +123,7 @@ func (email *EmailAddress) Activate() error {
 	if err != nil {
 		return err
 	}
-	if user.Rands, err = GetUserSalt(); err != nil {
+	if user.Rands, err = userutil.RandomSalt(); err != nil {
 		return err
 	}
 

+ 26 - 6
internal/db/users.go

@@ -62,6 +62,8 @@ type UsersStore interface {
 	GetByUsername(ctx context.Context, username string) (*User, error)
 	// HasForkedRepository returns true if the user has forked given repository.
 	HasForkedRepository(ctx context.Context, userID, repoID int64) bool
+	// IsUsernameUsed returns true if the given username has been used.
+	IsUsernameUsed(ctx context.Context, username string) bool
 	// ListFollowers returns a list of users that are following the given user.
 	// Results are paginated by given page and page size, and sorted by the time of
 	// follow in descending order.
@@ -231,11 +233,8 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
 		return nil, err
 	}
 
-	_, err = db.GetByUsername(ctx, username)
-	if err == nil {
+	if db.IsUsernameUsed(ctx, username) {
 		return nil, ErrUserAlreadyExist{args: errutil.Args{"name": username}}
-	} else if !IsErrUserNotExist(err) {
-		return nil, err
 	}
 
 	_, err = db.GetByEmail(ctx, email)
@@ -262,11 +261,11 @@ func (db *users) Create(ctx context.Context, username, email string, opts Create
 		AvatarEmail:     email,
 	}
 
-	user.Rands, err = GetUserSalt()
+	user.Rands, err = userutil.RandomSalt()
 	if err != nil {
 		return nil, err
 	}
-	user.Salt, err = GetUserSalt()
+	user.Salt, err = userutil.RandomSalt()
 	if err != nil {
 		return nil, err
 	}
@@ -371,6 +370,17 @@ func (db *users) HasForkedRepository(ctx context.Context, userID, repoID int64)
 	return count > 0
 }
 
+func (db *users) IsUsernameUsed(ctx context.Context, username string) bool {
+	if username == "" {
+		return false
+	}
+	return db.WithContext(ctx).
+		Select("id").
+		Where("lower_name = ?", strings.ToLower(username)).
+		First(&User{}).
+		Error != gorm.ErrRecordNotFound
+}
+
 func (db *users) ListFollowers(ctx context.Context, userID int64, page, pageSize int) ([]*User, error) {
 	/*
 		Equivalent SQL for PostgreSQL:
@@ -569,6 +579,16 @@ func (u *User) DisplayName() string {
 	return u.Name
 }
 
+// NewGhostUser creates and returns a fake user for people who has deleted their
+// accounts.
+func NewGhostUser() *User {
+	return &User{
+		ID:        -1,
+		Name:      "Ghost",
+		LowerName: "ghost",
+	}
+}
+
 // HomeURLPath returns the URL path to the user or organization home page.
 //
 // TODO(unknwon): This is also used in templates, which should be fixed by

+ 13 - 0
internal/db/users_test.go

@@ -93,6 +93,7 @@ func TestUsers(t *testing.T) {
 		{"GetByID", usersGetByID},
 		{"GetByUsername", usersGetByUsername},
 		{"HasForkedRepository", usersHasForkedRepository},
+		{"IsUsernameUsed", usersIsUsernameUsed},
 		{"ListFollowers", usersListFollowers},
 		{"ListFollowings", usersListFollowings},
 		{"UseCustomAvatar", usersUseCustomAvatar},
@@ -390,6 +391,18 @@ func usersHasForkedRepository(t *testing.T, db *users) {
 	assert.True(t, has)
 }
 
+func usersIsUsernameUsed(t *testing.T, db *users) {
+	ctx := context.Background()
+
+	alice, err := db.Create(ctx, "alice", "[email protected]", CreateUserOptions{})
+	require.NoError(t, err)
+
+	got := db.IsUsernameUsed(ctx, alice.Name)
+	assert.True(t, got)
+	got = db.IsUsernameUsed(ctx, "bob")
+	assert.False(t, got)
+}
+
 func usersListFollowers(t *testing.T, db *users) {
 	ctx := context.Background()
 

+ 1 - 1
internal/route/admin/users.go

@@ -189,7 +189,7 @@ func EditUserPost(c *context.Context, f form.AdminEditUser) {
 	if len(f.Password) > 0 {
 		u.Password = f.Password
 		var err error
-		if u.Salt, err = db.GetUserSalt(); err != nil {
+		if u.Salt, err = userutil.RandomSalt(); err != nil {
 			c.Error(err, "get user salt")
 			return
 		}

+ 1 - 1
internal/route/api/v1/admin/user.go

@@ -85,7 +85,7 @@ func EditUser(c *context.APIContext, form api.EditUserOption) {
 	if len(form.Password) > 0 {
 		u.Password = form.Password
 		var err error
-		if u.Salt, err = db.GetUserSalt(); err != nil {
+		if u.Salt, err = userutil.RandomSalt(); err != nil {
 			c.Error(err, "get user salt")
 			return
 		}

+ 121 - 0
internal/route/lfs/mocks_test.go

@@ -2313,6 +2313,9 @@ type MockUsersStore struct {
 	// HasForkedRepositoryFunc is an instance of a mock function object
 	// controlling the behavior of the method HasForkedRepository.
 	HasForkedRepositoryFunc *UsersStoreHasForkedRepositoryFunc
+	// IsUsernameUsedFunc is an instance of a mock function object
+	// controlling the behavior of the method IsUsernameUsed.
+	IsUsernameUsedFunc *UsersStoreIsUsernameUsedFunc
 	// ListFollowersFunc is an instance of a mock function object
 	// controlling the behavior of the method ListFollowers.
 	ListFollowersFunc *UsersStoreListFollowersFunc
@@ -2363,6 +2366,11 @@ func NewMockUsersStore() *MockUsersStore {
 				return
 			},
 		},
+		IsUsernameUsedFunc: &UsersStoreIsUsernameUsedFunc{
+			defaultHook: func(context.Context, string) (r0 bool) {
+				return
+			},
+		},
 		ListFollowersFunc: &UsersStoreListFollowersFunc{
 			defaultHook: func(context.Context, int64, int, int) (r0 []*db.User, r1 error) {
 				return
@@ -2420,6 +2428,11 @@ func NewStrictMockUsersStore() *MockUsersStore {
 				panic("unexpected invocation of MockUsersStore.HasForkedRepository")
 			},
 		},
+		IsUsernameUsedFunc: &UsersStoreIsUsernameUsedFunc{
+			defaultHook: func(context.Context, string) bool {
+				panic("unexpected invocation of MockUsersStore.IsUsernameUsed")
+			},
+		},
 		ListFollowersFunc: &UsersStoreListFollowersFunc{
 			defaultHook: func(context.Context, int64, int, int) ([]*db.User, error) {
 				panic("unexpected invocation of MockUsersStore.ListFollowers")
@@ -2463,6 +2476,9 @@ func NewMockUsersStoreFrom(i db.UsersStore) *MockUsersStore {
 		HasForkedRepositoryFunc: &UsersStoreHasForkedRepositoryFunc{
 			defaultHook: i.HasForkedRepository,
 		},
+		IsUsernameUsedFunc: &UsersStoreIsUsernameUsedFunc{
+			defaultHook: i.IsUsernameUsed,
+		},
 		ListFollowersFunc: &UsersStoreListFollowersFunc{
 			defaultHook: i.ListFollowers,
 		},
@@ -3242,6 +3258,111 @@ func (c UsersStoreHasForkedRepositoryFuncCall) Results() []interface{} {
 	return []interface{}{c.Result0}
 }
 
+// UsersStoreIsUsernameUsedFunc describes the behavior when the
+// IsUsernameUsed method of the parent MockUsersStore instance is invoked.
+type UsersStoreIsUsernameUsedFunc struct {
+	defaultHook func(context.Context, string) bool
+	hooks       []func(context.Context, string) bool
+	history     []UsersStoreIsUsernameUsedFuncCall
+	mutex       sync.Mutex
+}
+
+// IsUsernameUsed delegates to the next hook function in the queue and
+// stores the parameter and result values of this invocation.
+func (m *MockUsersStore) IsUsernameUsed(v0 context.Context, v1 string) bool {
+	r0 := m.IsUsernameUsedFunc.nextHook()(v0, v1)
+	m.IsUsernameUsedFunc.appendCall(UsersStoreIsUsernameUsedFuncCall{v0, v1, r0})
+	return r0
+}
+
+// SetDefaultHook sets function that is called when the IsUsernameUsed
+// method of the parent MockUsersStore instance is invoked and the hook
+// queue is empty.
+func (f *UsersStoreIsUsernameUsedFunc) SetDefaultHook(hook func(context.Context, string) bool) {
+	f.defaultHook = hook
+}
+
+// PushHook adds a function to the end of hook queue. Each invocation of the
+// IsUsernameUsed method of the parent MockUsersStore instance invokes the
+// hook at the front of the queue and discards it. After the queue is empty,
+// the default hook function is invoked for any future action.
+func (f *UsersStoreIsUsernameUsedFunc) PushHook(hook func(context.Context, string) bool) {
+	f.mutex.Lock()
+	f.hooks = append(f.hooks, hook)
+	f.mutex.Unlock()
+}
+
+// SetDefaultReturn calls SetDefaultHook with a function that returns the
+// given values.
+func (f *UsersStoreIsUsernameUsedFunc) SetDefaultReturn(r0 bool) {
+	f.SetDefaultHook(func(context.Context, string) bool {
+		return r0
+	})
+}
+
+// PushReturn calls PushHook with a function that returns the given values.
+func (f *UsersStoreIsUsernameUsedFunc) PushReturn(r0 bool) {
+	f.PushHook(func(context.Context, string) bool {
+		return r0
+	})
+}
+
+func (f *UsersStoreIsUsernameUsedFunc) nextHook() func(context.Context, string) bool {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+
+	if len(f.hooks) == 0 {
+		return f.defaultHook
+	}
+
+	hook := f.hooks[0]
+	f.hooks = f.hooks[1:]
+	return hook
+}
+
+func (f *UsersStoreIsUsernameUsedFunc) appendCall(r0 UsersStoreIsUsernameUsedFuncCall) {
+	f.mutex.Lock()
+	f.history = append(f.history, r0)
+	f.mutex.Unlock()
+}
+
+// History returns a sequence of UsersStoreIsUsernameUsedFuncCall objects
+// describing the invocations of this function.
+func (f *UsersStoreIsUsernameUsedFunc) History() []UsersStoreIsUsernameUsedFuncCall {
+	f.mutex.Lock()
+	history := make([]UsersStoreIsUsernameUsedFuncCall, len(f.history))
+	copy(history, f.history)
+	f.mutex.Unlock()
+
+	return history
+}
+
+// UsersStoreIsUsernameUsedFuncCall is an object that describes an
+// invocation of method IsUsernameUsed on an instance of MockUsersStore.
+type UsersStoreIsUsernameUsedFuncCall struct {
+	// Arg0 is the value of the 1st argument passed to this method
+	// invocation.
+	Arg0 context.Context
+	// Arg1 is the value of the 2nd argument passed to this method
+	// invocation.
+	Arg1 string
+	// Result0 is the value of the 1st result returned from this method
+	// invocation.
+	Result0 bool
+}
+
+// Args returns an interface slice containing the arguments of this
+// invocation.
+func (c UsersStoreIsUsernameUsedFuncCall) Args() []interface{} {
+	return []interface{}{c.Arg0, c.Arg1}
+}
+
+// Results returns an interface slice containing the results of this
+// invocation.
+func (c UsersStoreIsUsernameUsedFuncCall) Results() []interface{} {
+	return []interface{}{c.Result0}
+}
+
 // UsersStoreListFollowersFunc describes the behavior when the ListFollowers
 // method of the parent MockUsersStore instance is invoked.
 type UsersStoreListFollowersFunc struct {

+ 2 - 6
internal/route/org/setting.go

@@ -41,15 +41,11 @@ func SettingsPost(c *context.Context, f form.UpdateOrgSetting) {
 
 	// Check if organization name has been changed.
 	if org.LowerName != strings.ToLower(f.Name) {
-		isExist, err := db.IsUserExist(org.ID, f.Name)
-		if err != nil {
-			c.Error(err, "check if user exists")
-			return
-		} else if isExist {
+		if db.Users.IsUsernameUsed(c.Req.Context(), f.Name) {
 			c.Data["OrgName"] = true
 			c.RenderWithErr(c.Tr("form.username_been_taken"), SETTINGS_OPTIONS, &f)
 			return
-		} else if err = db.ChangeUserName(org, f.Name); err != nil {
+		} else if err := db.ChangeUserName(org, f.Name); err != nil {
 			c.Data["OrgName"] = true
 			switch {
 			case db.IsErrNameNotAllowed(err):

+ 2 - 6
internal/route/repo/setting.go

@@ -225,16 +225,12 @@ func SettingsPost(c *context.Context, f form.RepoSetting) {
 		}
 
 		newOwner := c.Query("new_owner_name")
-		isExist, err := db.IsUserExist(0, newOwner)
-		if err != nil {
-			c.Error(err, "check if user exists")
-			return
-		} else if !isExist {
+		if !db.Users.IsUsernameUsed(c.Req.Context(), newOwner) {
 			c.RenderWithErr(c.Tr("form.enterred_invalid_owner_name"), SETTINGS_OPTIONS, nil)
 			return
 		}
 
-		if err = db.TransferOwnership(c.User, newOwner, repo); err != nil {
+		if err := db.TransferOwnership(c.User, newOwner, repo); err != nil {
 			if db.IsErrRepoAlreadyExist(err) {
 				c.RenderWithErr(c.Tr("repo.settings.new_owner_has_same_repo"), SETTINGS_OPTIONS, nil)
 			} else {

+ 3 - 3
internal/route/user/auth.go

@@ -414,7 +414,7 @@ func Activate(c *context.Context) {
 	if user := db.VerifyUserActiveCode(code); user != nil {
 		user.IsActive = true
 		var err error
-		if user.Rands, err = db.GetUserSalt(); err != nil {
+		if user.Rands, err = userutil.RandomSalt(); err != nil {
 			c.Error(err, "get user salt")
 			return
 		}
@@ -547,11 +547,11 @@ func ResetPasswdPost(c *context.Context) {
 
 		u.Password = passwd
 		var err error
-		if u.Rands, err = db.GetUserSalt(); err != nil {
+		if u.Rands, err = userutil.RandomSalt(); err != nil {
 			c.Error(err, "get user salt")
 			return
 		}
-		if u.Salt, err = db.GetUserSalt(); err != nil {
+		if u.Salt, err = userutil.RandomSalt(); err != nil {
 			c.Error(err, "get user salt")
 			return
 		}

+ 1 - 1
internal/route/user/setting.go

@@ -198,7 +198,7 @@ func SettingsPasswordPost(c *context.Context, f form.ChangePassword) {
 	} else {
 		c.User.Password = f.Password
 		var err error
-		if c.User.Salt, err = db.GetUserSalt(); err != nil {
+		if c.User.Salt, err = userutil.RandomSalt(); err != nil {
 			c.Errorf(err, "get user salt")
 			return
 		}

+ 7 - 0
internal/userutil/userutil.go

@@ -23,6 +23,7 @@ import (
 
 	"gogs.io/gogs/internal/avatar"
 	"gogs.io/gogs/internal/conf"
+	"gogs.io/gogs/internal/strutil"
 	"gogs.io/gogs/internal/tool"
 )
 
@@ -132,3 +133,9 @@ func MailResendCacheKey(userID int64) string {
 func TwoFactorCacheKey(userID int64, passcode string) string {
 	return fmt.Sprintf("twoFactor::%d::%s", userID, passcode)
 }
+
+// RandomSalt returns randomly generated 10-character string that can be used as
+// the user salt.
+func RandomSalt() (string, error) {
+	return strutil.RandomChars(10)
+}

+ 8 - 0
internal/userutil/userutil_test.go

@@ -191,3 +191,11 @@ func TestTwoFactorCacheKey(t *testing.T) {
 	got := TwoFactorCacheKey(1, "113654")
 	assert.Equal(t, "twoFactor::1::113654", got)
 }
+
+func TestRandomSalt(t *testing.T) {
+	salt1, err := RandomSalt()
+	require.NoError(t, err)
+	salt2, err := RandomSalt()
+	require.NoError(t, err)
+	assert.NotEqual(t, salt1, salt2)
+}