dbtest.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package dbtest
  5. import (
  6. "database/sql"
  7. "fmt"
  8. "os"
  9. "path/filepath"
  10. "testing"
  11. "time"
  12. "github.com/stretchr/testify/require"
  13. "gorm.io/gorm"
  14. "gorm.io/gorm/schema"
  15. "gogs.io/gogs/internal/conf"
  16. "gogs.io/gogs/internal/dbutil"
  17. )
  18. // NewDB creates a new test database and initializes the given list of tables
  19. // for the suite. The test database is dropped after testing is completed unless
  20. // failed.
  21. func NewDB(t *testing.T, suite string, tables ...any) *gorm.DB {
  22. dbType := os.Getenv("GOGS_DATABASE_TYPE")
  23. var dbName string
  24. var dbOpts conf.DatabaseOpts
  25. var cleanup func(db *gorm.DB)
  26. switch dbType {
  27. case "mysql":
  28. dbOpts = conf.DatabaseOpts{
  29. Type: "mysql",
  30. Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
  31. Name: dbName,
  32. User: os.Getenv("MYSQL_USER"),
  33. Password: os.Getenv("MYSQL_PASSWORD"),
  34. }
  35. dsn, err := dbutil.NewDSN(dbOpts)
  36. require.NoError(t, err)
  37. sqlDB, err := sql.Open("mysql", dsn)
  38. require.NoError(t, err)
  39. // Set up test database
  40. dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
  41. _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
  42. require.NoError(t, err)
  43. _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
  44. require.NoError(t, err)
  45. dbOpts.Name = dbName
  46. cleanup = func(_ *gorm.DB) {
  47. _, _ = sqlDB.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
  48. _ = sqlDB.Close()
  49. }
  50. case "postgres":
  51. dbOpts = conf.DatabaseOpts{
  52. Type: "postgres",
  53. Host: os.ExpandEnv("$PGHOST:$PGPORT"),
  54. Name: dbName,
  55. Schema: "public",
  56. User: os.Getenv("PGUSER"),
  57. Password: os.Getenv("PGPASSWORD"),
  58. SSLMode: os.Getenv("PGSSLMODE"),
  59. }
  60. dsn, err := dbutil.NewDSN(dbOpts)
  61. require.NoError(t, err)
  62. sqlDB, err := sql.Open("pgx", dsn)
  63. require.NoError(t, err)
  64. // Set up test database
  65. dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
  66. _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
  67. require.NoError(t, err)
  68. _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
  69. require.NoError(t, err)
  70. dbOpts.Name = dbName
  71. cleanup = func(_ *gorm.DB) {
  72. _, _ = sqlDB.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
  73. _ = sqlDB.Close()
  74. }
  75. case "sqlite":
  76. dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
  77. dbOpts = conf.DatabaseOpts{
  78. Type: "sqlite",
  79. Path: dbName,
  80. }
  81. cleanup = func(db *gorm.DB) {
  82. sqlDB, err := db.DB()
  83. if err == nil {
  84. _ = sqlDB.Close()
  85. }
  86. _ = os.Remove(dbName)
  87. }
  88. default:
  89. dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
  90. dbOpts = conf.DatabaseOpts{
  91. Type: "sqlite3",
  92. Path: dbName,
  93. }
  94. cleanup = func(db *gorm.DB) {
  95. sqlDB, err := db.DB()
  96. if err == nil {
  97. _ = sqlDB.Close()
  98. }
  99. _ = os.Remove(dbName)
  100. }
  101. }
  102. now := time.Now().UTC().Truncate(time.Second)
  103. db, err := dbutil.OpenDB(
  104. dbOpts,
  105. &gorm.Config{
  106. SkipDefaultTransaction: true,
  107. NamingStrategy: schema.NamingStrategy{
  108. SingularTable: true,
  109. },
  110. NowFunc: func() time.Time {
  111. return now
  112. },
  113. },
  114. )
  115. require.NoError(t, err)
  116. t.Cleanup(func() {
  117. if t.Failed() {
  118. t.Logf("Database %q left intact for inspection", dbName)
  119. return
  120. }
  121. cleanup(db)
  122. })
  123. err = db.Migrator().AutoMigrate(tables...)
  124. require.NoError(t, err)
  125. return db
  126. }