login_sources_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. mockrequire "github.com/derision-test/go-mockgen/testutil/require"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/stretchr/testify/require"
  12. "gorm.io/gorm"
  13. "gogs.io/gogs/internal/auth"
  14. "gogs.io/gogs/internal/auth/github"
  15. "gogs.io/gogs/internal/auth/ldap"
  16. "gogs.io/gogs/internal/auth/pam"
  17. "gogs.io/gogs/internal/auth/smtp"
  18. "gogs.io/gogs/internal/dbtest"
  19. "gogs.io/gogs/internal/errutil"
  20. )
  21. func TestLoginSource_BeforeSave(t *testing.T) {
  22. now := time.Now()
  23. db := &gorm.DB{
  24. Config: &gorm.Config{
  25. SkipDefaultTransaction: true,
  26. NowFunc: func() time.Time {
  27. return now
  28. },
  29. },
  30. }
  31. t.Run("Config has not been set", func(t *testing.T) {
  32. s := &LoginSource{}
  33. err := s.BeforeSave(db)
  34. require.NoError(t, err)
  35. assert.Empty(t, s.Config)
  36. })
  37. t.Run("Config has been set", func(t *testing.T) {
  38. s := &LoginSource{
  39. Provider: pam.NewProvider(&pam.Config{
  40. ServiceName: "pam_service",
  41. }),
  42. }
  43. err := s.BeforeSave(db)
  44. require.NoError(t, err)
  45. assert.Equal(t, `{"ServiceName":"pam_service"}`, s.Config)
  46. })
  47. }
  48. func TestLoginSource_BeforeCreate(t *testing.T) {
  49. now := time.Now()
  50. db := &gorm.DB{
  51. Config: &gorm.Config{
  52. SkipDefaultTransaction: true,
  53. NowFunc: func() time.Time {
  54. return now
  55. },
  56. },
  57. }
  58. t.Run("CreatedUnix has been set", func(t *testing.T) {
  59. s := &LoginSource{
  60. CreatedUnix: 1,
  61. }
  62. _ = s.BeforeCreate(db)
  63. assert.Equal(t, int64(1), s.CreatedUnix)
  64. assert.Equal(t, int64(0), s.UpdatedUnix)
  65. })
  66. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  67. s := &LoginSource{}
  68. _ = s.BeforeCreate(db)
  69. assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
  70. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  71. })
  72. }
  73. func TestLoginSource_BeforeUpdate(t *testing.T) {
  74. now := time.Now()
  75. db := &gorm.DB{
  76. Config: &gorm.Config{
  77. SkipDefaultTransaction: true,
  78. NowFunc: func() time.Time {
  79. return now
  80. },
  81. },
  82. }
  83. s := &LoginSource{}
  84. _ = s.BeforeUpdate(db)
  85. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  86. }
  87. func TestLoginSource_AfterFind(t *testing.T) {
  88. now := time.Now()
  89. db := &gorm.DB{
  90. Config: &gorm.Config{
  91. SkipDefaultTransaction: true,
  92. NowFunc: func() time.Time {
  93. return now
  94. },
  95. },
  96. }
  97. tests := []struct {
  98. name string
  99. authType auth.Type
  100. wantType any
  101. }{
  102. {
  103. name: "LDAP",
  104. authType: auth.LDAP,
  105. wantType: &ldap.Provider{},
  106. },
  107. {
  108. name: "DLDAP",
  109. authType: auth.DLDAP,
  110. wantType: &ldap.Provider{},
  111. },
  112. {
  113. name: "SMTP",
  114. authType: auth.SMTP,
  115. wantType: &smtp.Provider{},
  116. },
  117. {
  118. name: "PAM",
  119. authType: auth.PAM,
  120. wantType: &pam.Provider{},
  121. },
  122. {
  123. name: "GitHub",
  124. authType: auth.GitHub,
  125. wantType: &github.Provider{},
  126. },
  127. }
  128. for _, test := range tests {
  129. t.Run(test.name, func(t *testing.T) {
  130. s := LoginSource{
  131. Type: test.authType,
  132. Config: `{}`,
  133. CreatedUnix: now.Unix(),
  134. UpdatedUnix: now.Unix(),
  135. }
  136. err := s.AfterFind(db)
  137. require.NoError(t, err)
  138. assert.Equal(t, s.CreatedUnix, s.Created.Unix())
  139. assert.Equal(t, s.UpdatedUnix, s.Updated.Unix())
  140. assert.IsType(t, test.wantType, s.Provider)
  141. })
  142. }
  143. }
  144. func TestLoginSources(t *testing.T) {
  145. if testing.Short() {
  146. t.Skip()
  147. }
  148. t.Parallel()
  149. tables := []interface{}{new(LoginSource), new(User)}
  150. db := &loginSources{
  151. DB: dbtest.NewDB(t, "loginSources", tables...),
  152. }
  153. for _, tc := range []struct {
  154. name string
  155. test func(t *testing.T, db *loginSources)
  156. }{
  157. {"Create", loginSourcesCreate},
  158. {"Count", loginSourcesCount},
  159. {"DeleteByID", loginSourcesDeleteByID},
  160. {"GetByID", loginSourcesGetByID},
  161. {"List", loginSourcesList},
  162. {"ResetNonDefault", loginSourcesResetNonDefault},
  163. {"Save", loginSourcesSave},
  164. } {
  165. t.Run(tc.name, func(t *testing.T) {
  166. t.Cleanup(func() {
  167. err := clearTables(t, db.DB, tables...)
  168. require.NoError(t, err)
  169. })
  170. tc.test(t, db)
  171. })
  172. if t.Failed() {
  173. break
  174. }
  175. }
  176. }
  177. func loginSourcesCreate(t *testing.T, db *loginSources) {
  178. ctx := context.Background()
  179. // Create first login source with name "GitHub"
  180. source, err := db.Create(ctx,
  181. CreateLoginSourceOptions{
  182. Type: auth.GitHub,
  183. Name: "GitHub",
  184. Activated: true,
  185. Default: false,
  186. Config: &github.Config{
  187. APIEndpoint: "https://api.github.com",
  188. },
  189. },
  190. )
  191. require.NoError(t, err)
  192. // Get it back and check the Created field
  193. source, err = db.GetByID(ctx, source.ID)
  194. require.NoError(t, err)
  195. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
  196. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
  197. // Try create second login source with same name should fail
  198. _, err = db.Create(ctx, CreateLoginSourceOptions{Name: source.Name})
  199. wantErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
  200. assert.Equal(t, wantErr, err)
  201. }
  202. func loginSourcesCount(t *testing.T, db *loginSources) {
  203. ctx := context.Background()
  204. // Create two login sources, one in database and one as source file.
  205. _, err := db.Create(ctx,
  206. CreateLoginSourceOptions{
  207. Type: auth.GitHub,
  208. Name: "GitHub",
  209. Activated: true,
  210. Default: false,
  211. Config: &github.Config{
  212. APIEndpoint: "https://api.github.com",
  213. },
  214. },
  215. )
  216. require.NoError(t, err)
  217. mock := NewMockLoginSourceFilesStore()
  218. mock.LenFunc.SetDefaultReturn(2)
  219. setMockLoginSourceFilesStore(t, db, mock)
  220. assert.Equal(t, int64(3), db.Count(ctx))
  221. }
  222. func loginSourcesDeleteByID(t *testing.T, db *loginSources) {
  223. ctx := context.Background()
  224. t.Run("delete but in used", func(t *testing.T) {
  225. source, err := db.Create(ctx,
  226. CreateLoginSourceOptions{
  227. Type: auth.GitHub,
  228. Name: "GitHub",
  229. Activated: true,
  230. Default: false,
  231. Config: &github.Config{
  232. APIEndpoint: "https://api.github.com",
  233. },
  234. },
  235. )
  236. require.NoError(t, err)
  237. // Create a user that uses this login source
  238. _, err = (&users{DB: db.DB}).Create(ctx, "alice", "",
  239. CreateUserOptions{
  240. LoginSource: source.ID,
  241. },
  242. )
  243. require.NoError(t, err)
  244. // Delete the login source will result in error
  245. err = db.DeleteByID(ctx, source.ID)
  246. wantErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
  247. assert.Equal(t, wantErr, err)
  248. })
  249. mock := NewMockLoginSourceFilesStore()
  250. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  251. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  252. })
  253. setMockLoginSourceFilesStore(t, db, mock)
  254. // Create a login source with name "GitHub2"
  255. source, err := db.Create(ctx,
  256. CreateLoginSourceOptions{
  257. Type: auth.GitHub,
  258. Name: "GitHub2",
  259. Activated: true,
  260. Default: false,
  261. Config: &github.Config{
  262. APIEndpoint: "https://api.github.com",
  263. },
  264. },
  265. )
  266. require.NoError(t, err)
  267. // Delete a non-existent ID is noop
  268. err = db.DeleteByID(ctx, 9999)
  269. require.NoError(t, err)
  270. // We should be able to get it back
  271. _, err = db.GetByID(ctx, source.ID)
  272. require.NoError(t, err)
  273. // Now delete this login source with ID
  274. err = db.DeleteByID(ctx, source.ID)
  275. require.NoError(t, err)
  276. // We should get token not found error
  277. _, err = db.GetByID(ctx, source.ID)
  278. wantErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
  279. assert.Equal(t, wantErr, err)
  280. }
  281. func loginSourcesGetByID(t *testing.T, db *loginSources) {
  282. ctx := context.Background()
  283. mock := NewMockLoginSourceFilesStore()
  284. mock.GetByIDFunc.SetDefaultHook(func(id int64) (*LoginSource, error) {
  285. if id != 101 {
  286. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  287. }
  288. return &LoginSource{ID: id}, nil
  289. })
  290. setMockLoginSourceFilesStore(t, db, mock)
  291. expConfig := &github.Config{
  292. APIEndpoint: "https://api.github.com",
  293. }
  294. // Create a login source with name "GitHub"
  295. source, err := db.Create(ctx,
  296. CreateLoginSourceOptions{
  297. Type: auth.GitHub,
  298. Name: "GitHub",
  299. Activated: true,
  300. Default: false,
  301. Config: expConfig,
  302. },
  303. )
  304. require.NoError(t, err)
  305. // Get the one in the database and test the read/write hooks
  306. source, err = db.GetByID(ctx, source.ID)
  307. require.NoError(t, err)
  308. assert.Equal(t, expConfig, source.Provider.Config())
  309. // Get the one in source file store
  310. _, err = db.GetByID(ctx, 101)
  311. require.NoError(t, err)
  312. }
  313. func loginSourcesList(t *testing.T, db *loginSources) {
  314. ctx := context.Background()
  315. mock := NewMockLoginSourceFilesStore()
  316. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
  317. if opts.OnlyActivated {
  318. return []*LoginSource{
  319. {ID: 1},
  320. }
  321. }
  322. return []*LoginSource{
  323. {ID: 1},
  324. {ID: 2},
  325. }
  326. })
  327. setMockLoginSourceFilesStore(t, db, mock)
  328. // Create two login sources in database, one activated and the other one not
  329. _, err := db.Create(ctx,
  330. CreateLoginSourceOptions{
  331. Type: auth.PAM,
  332. Name: "PAM",
  333. Config: &pam.Config{
  334. ServiceName: "PAM",
  335. },
  336. },
  337. )
  338. require.NoError(t, err)
  339. _, err = db.Create(ctx,
  340. CreateLoginSourceOptions{
  341. Type: auth.GitHub,
  342. Name: "GitHub",
  343. Activated: true,
  344. Config: &github.Config{
  345. APIEndpoint: "https://api.github.com",
  346. },
  347. },
  348. )
  349. require.NoError(t, err)
  350. // List all login sources
  351. sources, err := db.List(ctx, ListLoginSourceOptions{})
  352. require.NoError(t, err)
  353. assert.Equal(t, 4, len(sources), "number of sources")
  354. // Only list activated login sources
  355. sources, err = db.List(ctx, ListLoginSourceOptions{OnlyActivated: true})
  356. require.NoError(t, err)
  357. assert.Equal(t, 2, len(sources), "number of sources")
  358. }
  359. func loginSourcesResetNonDefault(t *testing.T, db *loginSources) {
  360. ctx := context.Background()
  361. mock := NewMockLoginSourceFilesStore()
  362. mock.ListFunc.SetDefaultHook(func(opts ListLoginSourceOptions) []*LoginSource {
  363. mockFile := NewMockLoginSourceFileStore()
  364. mockFile.SetGeneralFunc.SetDefaultHook(func(name, value string) {
  365. assert.Equal(t, "is_default", name)
  366. assert.Equal(t, "false", value)
  367. })
  368. return []*LoginSource{
  369. {
  370. File: mockFile,
  371. },
  372. }
  373. })
  374. setMockLoginSourceFilesStore(t, db, mock)
  375. // Create two login sources both have default on
  376. source1, err := db.Create(ctx,
  377. CreateLoginSourceOptions{
  378. Type: auth.PAM,
  379. Name: "PAM",
  380. Default: true,
  381. Config: &pam.Config{
  382. ServiceName: "PAM",
  383. },
  384. },
  385. )
  386. require.NoError(t, err)
  387. source2, err := db.Create(ctx,
  388. CreateLoginSourceOptions{
  389. Type: auth.GitHub,
  390. Name: "GitHub",
  391. Activated: true,
  392. Default: true,
  393. Config: &github.Config{
  394. APIEndpoint: "https://api.github.com",
  395. },
  396. },
  397. )
  398. require.NoError(t, err)
  399. // Set source 1 as default
  400. err = db.ResetNonDefault(ctx, source1)
  401. require.NoError(t, err)
  402. // Verify the default state
  403. source1, err = db.GetByID(ctx, source1.ID)
  404. require.NoError(t, err)
  405. assert.True(t, source1.IsDefault)
  406. source2, err = db.GetByID(ctx, source2.ID)
  407. require.NoError(t, err)
  408. assert.False(t, source2.IsDefault)
  409. }
  410. func loginSourcesSave(t *testing.T, db *loginSources) {
  411. ctx := context.Background()
  412. t.Run("save to database", func(t *testing.T) {
  413. // Create a login source with name "GitHub"
  414. source, err := db.Create(ctx,
  415. CreateLoginSourceOptions{
  416. Type: auth.GitHub,
  417. Name: "GitHub",
  418. Activated: true,
  419. Default: false,
  420. Config: &github.Config{
  421. APIEndpoint: "https://api.github.com",
  422. },
  423. },
  424. )
  425. require.NoError(t, err)
  426. source.IsActived = false
  427. source.Provider = github.NewProvider(&github.Config{
  428. APIEndpoint: "https://api2.github.com",
  429. })
  430. err = db.Save(ctx, source)
  431. require.NoError(t, err)
  432. source, err = db.GetByID(ctx, source.ID)
  433. require.NoError(t, err)
  434. assert.False(t, source.IsActived)
  435. assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
  436. })
  437. t.Run("save to file", func(t *testing.T) {
  438. mockFile := NewMockLoginSourceFileStore()
  439. source := &LoginSource{
  440. Provider: github.NewProvider(&github.Config{
  441. APIEndpoint: "https://api.github.com",
  442. }),
  443. File: mockFile,
  444. }
  445. err := db.Save(ctx, source)
  446. require.NoError(t, err)
  447. mockrequire.Called(t, mockFile.SaveFunc)
  448. })
  449. }