瀏覽代碼

Code dedoublication in models/models.go

Just some code dedoublication in models/models.go
Tristan Storch 10 年之前
父節點
當前提交
bdfdf3cacb
共有 1 個文件被更改,包括 15 次插入29 次删除
  1. 15 29
      models/models.go

+ 15 - 29
models/models.go

@@ -55,11 +55,12 @@ func LoadModelsConfig() {
 	DbCfg.Path = setting.Cfg.MustValue("database", "PATH", "data/gogs.db")
 }
 
-func NewTestEngine(x *xorm.Engine) (err error) {
+func getEngine() (*xorm.Engine, error) {
+	cnnstr := ""
 	switch DbCfg.Type {
 	case "mysql":
-		x, err = xorm.NewEngine("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8",
-			DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name))
+		cnnstr = fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8",
+			DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name)
 	case "postgres":
 		var host, port = "127.0.0.1", "5432"
 		fields := strings.Split(DbCfg.Host, ":")
@@ -69,46 +70,31 @@ func NewTestEngine(x *xorm.Engine) (err error) {
 		if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 {
 			port = fields[1]
 		}
-		cnnstr := fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s",
+		cnnstr = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s",
 			DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode)
-		x, err = xorm.NewEngine("postgres", cnnstr)
 	case "sqlite3":
 		if !EnableSQLite3 {
-			return fmt.Errorf("Unknown database type: %s", DbCfg.Type)
+			return nil, fmt.Errorf("Unknown database type: %s", DbCfg.Type)
 		}
 		os.MkdirAll(path.Dir(DbCfg.Path), os.ModePerm)
-		x, err = xorm.NewEngine("sqlite3", DbCfg.Path)
+		cnnstr = DbCfg.Path
 	default:
-		return fmt.Errorf("Unknown database type: %s", DbCfg.Type)
+		return nil, fmt.Errorf("Unknown database type: %s", DbCfg.Type)
 	}
+	return xorm.NewEngine(DbCfg.Type, cnnstr)
+}
+
+func NewTestEngine(x *xorm.Engine) (err error) {
+	x, err = getEngine()
 	if err != nil {
 		return fmt.Errorf("models.init(fail to conntect database): %v", err)
 	}
+
 	return x.Sync(tables...)
 }
 
 func SetEngine() (err error) {
-	switch DbCfg.Type {
-	case "mysql":
-		x, err = xorm.NewEngine("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8",
-			DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name))
-	case "postgres":
-		var host, port = "127.0.0.1", "5432"
-		fields := strings.Split(DbCfg.Host, ":")
-		if len(fields) > 0 && len(strings.TrimSpace(fields[0])) > 0 {
-			host = fields[0]
-		}
-		if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 {
-			port = fields[1]
-		}
-		x, err = xorm.NewEngine("postgres", fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s",
-			DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode))
-	case "sqlite3":
-		os.MkdirAll(path.Dir(DbCfg.Path), os.ModePerm)
-		x, err = xorm.NewEngine("sqlite3", DbCfg.Path)
-	default:
-		return fmt.Errorf("Unknown database type: %s", DbCfg.Type)
-	}
+	x, err = getEngine()
 	if err != nil {
 		return fmt.Errorf("models.init(fail to conntect database): %v", err)
 	}