notices_test.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. // Copyright 2023 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "gorm.io/gorm"
  12. "gogs.io/gogs/internal/dbtest"
  13. )
  14. func TestNotice_BeforeCreate(t *testing.T) {
  15. now := time.Now()
  16. db := &gorm.DB{
  17. Config: &gorm.Config{
  18. SkipDefaultTransaction: true,
  19. NowFunc: func() time.Time {
  20. return now
  21. },
  22. },
  23. }
  24. t.Run("CreatedUnix has been set", func(t *testing.T) {
  25. notice := &Notice{
  26. CreatedUnix: 1,
  27. }
  28. _ = notice.BeforeCreate(db)
  29. assert.Equal(t, int64(1), notice.CreatedUnix)
  30. })
  31. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  32. notice := &Notice{}
  33. _ = notice.BeforeCreate(db)
  34. assert.Equal(t, db.NowFunc().Unix(), notice.CreatedUnix)
  35. })
  36. }
  37. func TestNotice_AfterFind(t *testing.T) {
  38. now := time.Now()
  39. db := &gorm.DB{
  40. Config: &gorm.Config{
  41. SkipDefaultTransaction: true,
  42. NowFunc: func() time.Time {
  43. return now
  44. },
  45. },
  46. }
  47. notice := &Notice{
  48. CreatedUnix: now.Unix(),
  49. }
  50. _ = notice.AfterFind(db)
  51. assert.Equal(t, notice.CreatedUnix, notice.Created.Unix())
  52. }
  53. func TestNotices(t *testing.T) {
  54. if testing.Short() {
  55. t.Skip()
  56. }
  57. t.Parallel()
  58. tables := []any{new(Notice)}
  59. db := &notices{
  60. DB: dbtest.NewDB(t, "notices", tables...),
  61. }
  62. for _, tc := range []struct {
  63. name string
  64. test func(t *testing.T, db *notices)
  65. }{
  66. {"Create", noticesCreate},
  67. {"DeleteByIDs", noticesDeleteByIDs},
  68. {"DeleteAll", noticesDeleteAll},
  69. {"List", noticesList},
  70. {"Count", noticesCount},
  71. } {
  72. t.Run(tc.name, func(t *testing.T) {
  73. t.Cleanup(func() {
  74. err := clearTables(t, db.DB, tables...)
  75. require.NoError(t, err)
  76. })
  77. tc.test(t, db)
  78. })
  79. if t.Failed() {
  80. break
  81. }
  82. }
  83. }
  84. func noticesCreate(t *testing.T, db *notices) {
  85. ctx := context.Background()
  86. err := db.Create(ctx, NoticeTypeRepository, "test")
  87. require.NoError(t, err)
  88. count := db.Count(ctx)
  89. assert.Equal(t, int64(1), count)
  90. }
  91. func noticesDeleteByIDs(t *testing.T, db *notices) {
  92. ctx := context.Background()
  93. err := db.Create(ctx, NoticeTypeRepository, "test")
  94. require.NoError(t, err)
  95. notices, err := db.List(ctx, 1, 10)
  96. require.NoError(t, err)
  97. ids := make([]int64, 0, len(notices))
  98. for _, notice := range notices {
  99. ids = append(ids, notice.ID)
  100. }
  101. // Non-existing IDs should be ignored
  102. ids = append(ids, 404)
  103. err = db.DeleteByIDs(ctx, ids...)
  104. require.NoError(t, err)
  105. count := db.Count(ctx)
  106. assert.Equal(t, int64(0), count)
  107. }
  108. func noticesDeleteAll(t *testing.T, db *notices) {
  109. ctx := context.Background()
  110. err := db.Create(ctx, NoticeTypeRepository, "test")
  111. require.NoError(t, err)
  112. err = db.DeleteAll(ctx)
  113. require.NoError(t, err)
  114. count := db.Count(ctx)
  115. assert.Equal(t, int64(0), count)
  116. }
  117. func noticesList(t *testing.T, db *notices) {
  118. ctx := context.Background()
  119. err := db.Create(ctx, NoticeTypeRepository, "test 1")
  120. require.NoError(t, err)
  121. err = db.Create(ctx, NoticeTypeRepository, "test 2")
  122. require.NoError(t, err)
  123. got1, err := db.List(ctx, 1, 1)
  124. require.NoError(t, err)
  125. require.Len(t, got1, 1)
  126. got2, err := db.List(ctx, 2, 1)
  127. require.NoError(t, err)
  128. require.Len(t, got2, 1)
  129. assert.True(t, got1[0].ID > got2[0].ID)
  130. got, err := db.List(ctx, 1, 3)
  131. require.NoError(t, err)
  132. require.Len(t, got, 2)
  133. }
  134. func noticesCount(t *testing.T, db *notices) {
  135. ctx := context.Background()
  136. count := db.Count(ctx)
  137. assert.Equal(t, int64(0), count)
  138. err := db.Create(ctx, NoticeTypeRepository, "test")
  139. require.NoError(t, err)
  140. count = db.Count(ctx)
  141. assert.Equal(t, int64(1), count)
  142. }