Pārlūkot izejas kodu

db: use `context` for backup and restore (#7044)

Joe Chen 2 gadi atpakaļ
vecāks
revīzija
75fbb82440
4 mainīti faili ar 57 papildinājumiem un 30 dzēšanām
  1. 2 1
      internal/cmd/backup.go
  2. 2 1
      internal/cmd/restore.go
  3. 44 17
      internal/db/backup.go
  4. 9 11
      internal/db/backup_test.go

+ 2 - 1
internal/cmd/backup.go

@@ -5,6 +5,7 @@
 package cmd
 
 import (
+	"context"
 	"fmt"
 	"io/ioutil"
 	"os"
@@ -94,7 +95,7 @@ func runBackup(c *cli.Context) error {
 
 	// Database
 	dbDir := filepath.Join(rootDir, "db")
-	if err = db.DumpDatabase(conn, dbDir, c.Bool("verbose")); err != nil {
+	if err = db.DumpDatabase(context.Background(), conn, dbDir, c.Bool("verbose")); err != nil {
 		log.Fatal("Failed to dump database: %v", err)
 	}
 	if err = z.AddDir(archiveRootDir+"/db", dbDir); err != nil {

+ 2 - 1
internal/cmd/restore.go

@@ -5,6 +5,7 @@
 package cmd
 
 import (
+	"context"
 	"os"
 	"path"
 	"path/filepath"
@@ -114,7 +115,7 @@ func runRestore(c *cli.Context) error {
 
 	// Database
 	dbDir := path.Join(archivePath, "db")
-	if err = db.ImportDatabase(conn, dbDir, c.Bool("verbose")); err != nil {
+	if err = db.ImportDatabase(context.Background(), conn, dbDir, c.Bool("verbose")); err != nil {
 		log.Fatal("Failed to import database: %v", err)
 	}
 

+ 44 - 17
internal/db/backup.go

@@ -3,6 +3,7 @@ package db
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"fmt"
 	"io"
 	"os"
@@ -30,18 +31,24 @@ func getTableType(t interface{}) string {
 }
 
 // DumpDatabase dumps all data from database to file system in JSON Lines format.
-func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
+func DumpDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error {
 	err := os.MkdirAll(dirPath, os.ModePerm)
 	if err != nil {
 		return err
 	}
 
-	err = dumpLegacyTables(dirPath, verbose)
+	err = dumpLegacyTables(ctx, dirPath, verbose)
 	if err != nil {
 		return errors.Wrap(err, "dump legacy tables")
 	}
 
 	for _, table := range Tables {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+
 		tableName := getTableType(table)
 		if verbose {
 			log.Trace("Dumping table %q...", tableName)
@@ -55,7 +62,7 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
 			}
 			defer func() { _ = f.Close() }()
 
-			return dumpTable(db, table, f)
+			return dumpTable(ctx, db, table, f)
 		}()
 		if err != nil {
 			return errors.Wrapf(err, "dump table %q", tableName)
@@ -65,11 +72,13 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
 	return nil
 }
 
-func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
-	query := db.Model(table).Order("id ASC")
+func dumpTable(ctx context.Context, db *gorm.DB, table interface{}, w io.Writer) error {
+	query := db.WithContext(ctx).Model(table)
 	switch table.(type) {
 	case *LFSObject:
-		query = db.Model(table).Order("repo_id, oid ASC")
+		query = query.Order("repo_id, oid ASC")
+	default:
+		query = query.Order("id ASC")
 	}
 
 	rows, err := query.Rows()
@@ -98,10 +107,16 @@ func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
 	return rows.Err()
 }
 
-func dumpLegacyTables(dirPath string, verbose bool) error {
+func dumpLegacyTables(ctx context.Context, dirPath string, verbose bool) error {
 	// Purposely create a local variable to not modify global variable
 	legacyTables := append(legacyTables, new(Version))
 	for _, table := range legacyTables {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+
 		tableName := getTableType(table)
 		if verbose {
 			log.Trace("Dumping table %q...", tableName)
@@ -113,7 +128,7 @@ func dumpLegacyTables(dirPath string, verbose bool) error {
 			return fmt.Errorf("create JSON file: %v", err)
 		}
 
-		if err = x.Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) {
+		if err = x.Context(ctx).Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) {
 			return jsoniter.NewEncoder(f).Encode(bean)
 		}); err != nil {
 			_ = f.Close()
@@ -125,13 +140,19 @@ func dumpLegacyTables(dirPath string, verbose bool) error {
 }
 
 // ImportDatabase imports data from backup archive in JSON Lines format.
-func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
-	err := importLegacyTables(dirPath, verbose)
+func ImportDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error {
+	err := importLegacyTables(ctx, dirPath, verbose)
 	if err != nil {
 		return errors.Wrap(err, "import legacy tables")
 	}
 
 	for _, table := range Tables {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+
 		tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
 		err := func() error {
 			tableFile := filepath.Join(dirPath, tableName+".json")
@@ -150,7 +171,7 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
 			}
 			defer func() { _ = f.Close() }()
 
-			return importTable(db, table, f)
+			return importTable(ctx, db, table, f)
 		}()
 		if err != nil {
 			return errors.Wrapf(err, "import table %q", tableName)
@@ -160,13 +181,13 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
 	return nil
 }
 
-func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
-	err := db.Migrator().DropTable(table)
+func importTable(ctx context.Context, db *gorm.DB, table interface{}, r io.Reader) error {
+	err := db.WithContext(ctx).Migrator().DropTable(table)
 	if err != nil {
 		return errors.Wrap(err, "drop table")
 	}
 
-	err = db.Migrator().AutoMigrate(table)
+	err = db.WithContext(ctx).Migrator().AutoMigrate(table)
 	if err != nil {
 		return errors.Wrap(err, "auto migrate")
 	}
@@ -191,7 +212,7 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
 			return errors.Wrap(err, "unmarshal JSON to struct")
 		}
 
-		err = db.Create(elem).Error
+		err = db.WithContext(ctx).Create(elem).Error
 		if err != nil {
 			return errors.Wrap(err, "create row")
 		}
@@ -200,14 +221,14 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
 	// PostgreSQL needs manually reset table sequence for auto increment keys
 	if conf.UsePostgreSQL && !skipResetIDSeq[rawTableName] {
 		seqName := rawTableName + "_id_seq"
-		if _, err = x.Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil {
+		if _, err = x.Context(ctx).Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil {
 			return errors.Wrapf(err, "reset table %q.%q", rawTableName, seqName)
 		}
 	}
 	return nil
 }
 
-func importLegacyTables(dirPath string, verbose bool) error {
+func importLegacyTables(ctx context.Context, dirPath string, verbose bool) error {
 	snakeMapper := core.SnakeMapper{}
 
 	skipInsertProcessors := map[string]bool{
@@ -218,6 +239,12 @@ func importLegacyTables(dirPath string, verbose bool) error {
 	// Purposely create a local variable to not modify global variable
 	legacyTables := append(legacyTables, new(Version))
 	for _, table := range legacyTables {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+
 		tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
 		tableFile := filepath.Join(dirPath, tableName+".json")
 		if !osutil.IsFile(tableFile) {

+ 9 - 11
internal/db/backup_test.go

@@ -6,12 +6,14 @@ package db
 
 import (
 	"bytes"
+	"context"
 	"os"
 	"path/filepath"
 	"testing"
 	"time"
 
 	"github.com/pkg/errors"
+	"github.com/stretchr/testify/require"
 	"gorm.io/gorm"
 
 	"gogs.io/gogs/internal/auth"
@@ -22,7 +24,7 @@ import (
 	"gogs.io/gogs/internal/testutil"
 )
 
-func Test_dumpAndImport(t *testing.T) {
+func TestDumpAndImport(t *testing.T) {
 	if testing.Short() {
 		t.Skip()
 	}
@@ -43,8 +45,6 @@ func Test_dumpAndImport(t *testing.T) {
 }
 
 func setupDBToDump(t *testing.T, db *gorm.DB) {
-	t.Helper()
-
 	vals := []interface{}{
 		&Access{
 			ID:     1,
@@ -126,31 +126,29 @@ func setupDBToDump(t *testing.T, db *gorm.DB) {
 	}
 	for _, val := range vals {
 		err := db.Create(val).Error
-		if err != nil {
-			t.Fatal(err)
-		}
+		require.NoError(t, err)
 	}
 }
 
 func dumpTables(t *testing.T, db *gorm.DB) {
-	t.Helper()
+	ctx := context.Background()
 
 	for _, table := range Tables {
 		tableName := getTableType(table)
 
 		var buf bytes.Buffer
-		err := dumpTable(db, table, &buf)
+		err := dumpTable(ctx, db, table, &buf)
 		if err != nil {
 			t.Fatalf("%s: %v", tableName, err)
 		}
 
 		golden := filepath.Join("testdata", "backup", tableName+".golden.json")
-		testutil.AssertGolden(t, golden, testutil.Update("Test_dumpAndImport"), buf.String())
+		testutil.AssertGolden(t, golden, testutil.Update("TestDumpAndImport"), buf.String())
 	}
 }
 
 func importTables(t *testing.T, db *gorm.DB) {
-	t.Helper()
+	ctx := context.Background()
 
 	for _, table := range Tables {
 		tableName := getTableType(table)
@@ -163,7 +161,7 @@ func importTables(t *testing.T, db *gorm.DB) {
 			}
 			defer func() { _ = f.Close() }()
 
-			return importTable(db, table, f)
+			return importTable(ctx, db, table, f)
 		}()
 		if err != nil {
 			t.Fatalf("%s: %v", tableName, err)