login_sources.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "fmt"
  7. "strconv"
  8. "time"
  9. jsoniter "github.com/json-iterator/go"
  10. "github.com/pkg/errors"
  11. "gorm.io/gorm"
  12. "gogs.io/gogs/internal/auth"
  13. "gogs.io/gogs/internal/auth/github"
  14. "gogs.io/gogs/internal/auth/ldap"
  15. "gogs.io/gogs/internal/auth/pam"
  16. "gogs.io/gogs/internal/auth/smtp"
  17. "gogs.io/gogs/internal/errutil"
  18. )
  19. // LoginSourcesStore is the persistent interface for login sources.
  20. //
  21. // NOTE: All methods are sorted in alphabetical order.
  22. type LoginSourcesStore interface {
  23. // Create creates a new login source and persist to database.
  24. // It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
  25. Create(opts CreateLoginSourceOpts) (*LoginSource, error)
  26. // Count returns the total number of login sources.
  27. Count() int64
  28. // DeleteByID deletes a login source by given ID.
  29. // It returns ErrLoginSourceInUse if at least one user is associated with the login source.
  30. DeleteByID(id int64) error
  31. // GetByID returns the login source with given ID.
  32. // It returns ErrLoginSourceNotExist when not found.
  33. GetByID(id int64) (*LoginSource, error)
  34. // List returns a list of login sources filtered by options.
  35. List(opts ListLoginSourceOpts) ([]*LoginSource, error)
  36. // ResetNonDefault clears default flag for all the other login sources.
  37. ResetNonDefault(source *LoginSource) error
  38. // Save persists all values of given login source to database or local file.
  39. // The Updated field is set to current time automatically.
  40. Save(t *LoginSource) error
  41. }
// LoginSources is the package-level LoginSourcesStore used to access login
// sources. It is not assigned in this file section; presumably it is
// initialized during database setup elsewhere — verify.
var LoginSources LoginSourcesStore
// LoginSource represents an external way for authorizing users.
//
// A source is either stored in the database or backed by a local file, in
// which case File is non-nil.
type LoginSource struct {
	ID   int64
	// Type selects which provider AfterFind builds from Config.
	Type auth.Type
	Name string `xorm:"UNIQUE" gorm:"UNIQUE"`
	// IsActived reports whether the source is enabled. The spelling is part of
	// the schema: queries filter on the "is_actived" column.
	IsActived bool `xorm:"NOT NULL DEFAULT false" gorm:"NOT NULL"`
	IsDefault bool `xorm:"DEFAULT false"`

	// Provider is the in-memory authentication provider. It is not persisted;
	// AfterFind rebuilds it from Config and BeforeSave serializes it back into
	// Config.
	Provider auth.Provider `xorm:"-" gorm:"-"`
	// Config holds the JSON-encoded provider configuration, stored in the
	// "cfg" TEXT column.
	Config string `xorm:"TEXT cfg" gorm:"COLUMN:cfg;TYPE:TEXT" json:"RawConfig"`

	// Created and Updated are not persisted. AfterFind derives them (in local
	// time) from CreatedUnix and UpdatedUnix, which are the stored columns.
	Created     time.Time `xorm:"-" gorm:"-" json:"-"`
	CreatedUnix int64
	Updated     time.Time `xorm:"-" gorm:"-" json:"-"`
	UpdatedUnix int64

	// File is non-nil when the source is loaded from a local file instead of
	// the database; saving such a source writes to this file.
	File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
}
  58. // NOTE: This is a GORM save hook.
  59. func (s *LoginSource) BeforeSave(_ *gorm.DB) (err error) {
  60. if s.Provider == nil {
  61. return nil
  62. }
  63. s.Config, err = jsoniter.MarshalToString(s.Provider.Config())
  64. return err
  65. }
  66. // NOTE: This is a GORM create hook.
  67. func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
  68. if s.CreatedUnix == 0 {
  69. s.CreatedUnix = tx.NowFunc().Unix()
  70. s.UpdatedUnix = s.CreatedUnix
  71. }
  72. return nil
  73. }
  74. // NOTE: This is a GORM update hook.
  75. func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
  76. s.UpdatedUnix = tx.NowFunc().Unix()
  77. return nil
  78. }
  79. // NOTE: This is a GORM query hook.
  80. func (s *LoginSource) AfterFind(_ *gorm.DB) error {
  81. s.Created = time.Unix(s.CreatedUnix, 0).Local()
  82. s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
  83. switch s.Type {
  84. case auth.LDAP:
  85. var cfg ldap.Config
  86. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  87. if err != nil {
  88. return err
  89. }
  90. s.Provider = ldap.NewProvider(false, &cfg)
  91. case auth.DLDAP:
  92. var cfg ldap.Config
  93. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  94. if err != nil {
  95. return err
  96. }
  97. s.Provider = ldap.NewProvider(true, &cfg)
  98. case auth.SMTP:
  99. var cfg smtp.Config
  100. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  101. if err != nil {
  102. return err
  103. }
  104. s.Provider = smtp.NewProvider(&cfg)
  105. case auth.PAM:
  106. var cfg pam.Config
  107. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  108. if err != nil {
  109. return err
  110. }
  111. s.Provider = pam.NewProvider(&cfg)
  112. case auth.GitHub:
  113. var cfg github.Config
  114. err := jsoniter.UnmarshalFromString(s.Config, &cfg)
  115. if err != nil {
  116. return err
  117. }
  118. s.Provider = github.NewProvider(&cfg)
  119. default:
  120. return fmt.Errorf("unrecognized login source type: %v", s.Type)
  121. }
  122. return nil
  123. }
  124. func (s *LoginSource) TypeName() string {
  125. return auth.Name(s.Type)
  126. }
  127. func (s *LoginSource) IsLDAP() bool {
  128. return s.Type == auth.LDAP
  129. }
  130. func (s *LoginSource) IsDLDAP() bool {
  131. return s.Type == auth.DLDAP
  132. }
  133. func (s *LoginSource) IsSMTP() bool {
  134. return s.Type == auth.SMTP
  135. }
  136. func (s *LoginSource) IsPAM() bool {
  137. return s.Type == auth.PAM
  138. }
  139. func (s *LoginSource) IsGitHub() bool {
  140. return s.Type == auth.GitHub
  141. }
  142. func (s *LoginSource) LDAP() *ldap.Config {
  143. return s.Provider.Config().(*ldap.Config)
  144. }
  145. func (s *LoginSource) SMTP() *smtp.Config {
  146. return s.Provider.Config().(*smtp.Config)
  147. }
  148. func (s *LoginSource) PAM() *pam.Config {
  149. return s.Provider.Config().(*pam.Config)
  150. }
  151. func (s *LoginSource) GitHub() *github.Config {
  152. return s.Provider.Config().(*github.Config)
  153. }
// Compile-time check that *loginSources implements LoginSourcesStore.
var _ LoginSourcesStore = (*loginSources)(nil)

// loginSources is the LoginSourcesStore implementation. Database-backed
// sources are accessed through the embedded *gorm.DB, and file-backed sources
// through files.
type loginSources struct {
	*gorm.DB
	files loginSourceFilesStore
}
// CreateLoginSourceOpts contains the options for creating a login source.
type CreateLoginSourceOpts struct {
	Type auth.Type
	Name string
	// Activated is stored as LoginSource.IsActived.
	Activated bool
	// Default is stored as LoginSource.IsDefault.
	Default bool
	// Config is marshaled to JSON and stored as LoginSource.Config.
	Config interface{}
}
  166. type ErrLoginSourceAlreadyExist struct {
  167. args errutil.Args
  168. }
  169. func IsErrLoginSourceAlreadyExist(err error) bool {
  170. _, ok := err.(ErrLoginSourceAlreadyExist)
  171. return ok
  172. }
  173. func (err ErrLoginSourceAlreadyExist) Error() string {
  174. return fmt.Sprintf("login source already exists: %v", err.args)
  175. }
  176. func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
  177. err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
  178. if err == nil {
  179. return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
  180. } else if err != gorm.ErrRecordNotFound {
  181. return nil, err
  182. }
  183. source := &LoginSource{
  184. Type: opts.Type,
  185. Name: opts.Name,
  186. IsActived: opts.Activated,
  187. IsDefault: opts.Default,
  188. }
  189. source.Config, err = jsoniter.MarshalToString(opts.Config)
  190. if err != nil {
  191. return nil, err
  192. }
  193. return source, db.DB.Create(source).Error
  194. }
  195. func (db *loginSources) Count() int64 {
  196. var count int64
  197. db.Model(new(LoginSource)).Count(&count)
  198. return count + int64(db.files.Len())
  199. }
  200. type ErrLoginSourceInUse struct {
  201. args errutil.Args
  202. }
  203. func IsErrLoginSourceInUse(err error) bool {
  204. _, ok := err.(ErrLoginSourceInUse)
  205. return ok
  206. }
  207. func (err ErrLoginSourceInUse) Error() string {
  208. return fmt.Sprintf("login source is still used by some users: %v", err.args)
  209. }
  210. func (db *loginSources) DeleteByID(id int64) error {
  211. var count int64
  212. err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
  213. if err != nil {
  214. return err
  215. } else if count > 0 {
  216. return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
  217. }
  218. return db.Where("id = ?", id).Delete(new(LoginSource)).Error
  219. }
  220. func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
  221. source := new(LoginSource)
  222. err := db.Where("id = ?", id).First(source).Error
  223. if err != nil {
  224. if err == gorm.ErrRecordNotFound {
  225. return db.files.GetByID(id)
  226. }
  227. return nil, err
  228. }
  229. return source, nil
  230. }
// ListLoginSourceOpts contains the options for listing login sources.
type ListLoginSourceOpts struct {
	// Whether to only include activated login sources.
	OnlyActivated bool
}
  235. func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
  236. var sources []*LoginSource
  237. query := db.Order("id ASC")
  238. if opts.OnlyActivated {
  239. query = query.Where("is_actived = ?", true)
  240. }
  241. err := query.Find(&sources).Error
  242. if err != nil {
  243. return nil, err
  244. }
  245. return append(sources, db.files.List(opts)...), nil
  246. }
  247. func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
  248. err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
  249. if err != nil {
  250. return err
  251. }
  252. for _, source := range db.files.List(ListLoginSourceOpts{}) {
  253. if source.File != nil && source.ID != dflt.ID {
  254. source.File.SetGeneral("is_default", "false")
  255. if err = source.File.Save(); err != nil {
  256. return errors.Wrap(err, "save file")
  257. }
  258. }
  259. }
  260. db.files.Update(dflt)
  261. return nil
  262. }
  263. func (db *loginSources) Save(source *LoginSource) error {
  264. if source.File == nil {
  265. return db.DB.Save(source).Error
  266. }
  267. source.File.SetGeneral("name", source.Name)
  268. source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
  269. source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
  270. if err := source.File.SetConfig(source.Provider.Config()); err != nil {
  271. return errors.Wrap(err, "set config")
  272. } else if err = source.File.Save(); err != nil {
  273. return errors.Wrap(err, "save file")
  274. }
  275. return nil
  276. }