Browse Source

add ssl support for web

Lunny Xiao 11 years ago
parent
commit
493b0c5ac2
4 changed files with 118 additions and 7 deletions
  1. 4 1
      conf/app.ini
  2. 94 2
      routers/repo/repo.go
  3. 6 0
      templates/status/401.tmpl
  4. 14 4
      web.go

+ 4 - 1
conf/app.ini

@@ -12,10 +12,13 @@ LANG_IGNS = Google Go|C|C++|Python|Ruby|C Sharp
 LICENSES = Apache v2 License|GPL v2|MIT License|Affero GPL|Artistic License 2.0|BSD (3-Clause) License
 
 [server]
+PROTOCOL = http
 DOMAIN = localhost
-ROOT_URL = http://%(DOMAIN)s:%(HTTP_PORT)s/
+ROOT_URL = %(PROTOCOL)://%(DOMAIN)s:%(HTTP_PORT)s/
 HTTP_ADDR = 
 HTTP_PORT = 3000
+CERT_FILE = cert.pem
+KEY_FILE = key.pem
 
 [database]
 ; Either "mysql", "postgres" or "sqlite3"(binary release only), it's your choice

+ 94 - 2
routers/repo/repo.go

@@ -5,6 +5,8 @@
 package repo
 
 import (
+	"encoding/base64"
+	"errors"
 	"fmt"
 	"path"
 	"path/filepath"
@@ -237,15 +239,105 @@ func SingleDownload(ctx *middleware.Context, params martini.Params) {
 	ctx.Res.Write(data)
 }
 
-func Http(ctx *middleware.Context, params martini.Params) {
-	// TODO: access check
+func basicEncode(username, password string) string {
+	auth := username + ":" + password
+	return base64.StdEncoding.EncodeToString([]byte(auth))
+}
+
+func basicDecode(encoded string) (user string, name string, err error) {
+	var s []byte
+	s, err = base64.StdEncoding.DecodeString(encoded)
+	if err != nil {
+		return
+	}
+
+	a := strings.Split(string(s), ":")
+	if len(a) == 2 {
+		user, name = a[0], a[1]
+	} else {
+		err = errors.New("decode failed")
+	}
+	return
+}
+
+func authRequired(ctx *middleware.Context) {
+	ctx.ResponseWriter.Header().Set("WWW-Authenticate", `Basic realm="Gogs Auth"`)
+	ctx.Data["ErrorMsg"] = "no basic auth and digit auth"
+	ctx.HTML(401, fmt.Sprintf("status/401"))
+}
 
+func Http(ctx *middleware.Context, params martini.Params) {
 	username := params["username"]
 	reponame := params["reponame"]
 	if strings.HasSuffix(reponame, ".git") {
 		reponame = reponame[:len(reponame)-4]
 	}
 
+	repoUser, err := models.GetUserByName(username)
+	if err != nil {
+		ctx.Handle(500, "repo.GetUserByName", nil)
+		return
+	}
+
+	repo, err := models.GetRepositoryByName(repoUser.Id, reponame)
+	if err != nil {
+		ctx.Handle(500, "repo.GetRepositoryByName", nil)
+		return
+	}
+
+	isPull := webdav.IsPullMethod(ctx.Req.Method)
+	var askAuth = !(!repo.IsPrivate && isPull)
+
+	//authRequired(ctx)
+	//return
+
+	// check access
+	if askAuth {
+		// check digit auth
+
+		// check basic auth
+		baHead := ctx.Req.Header.Get("Authorization")
+		if baHead != "" {
+			auths := strings.Fields(baHead)
+			if len(auths) != 2 || auths[0] != "Basic" {
+				ctx.Handle(401, "no basic auth and digit auth", nil)
+				return
+			}
+			authUsername, passwd, err := basicDecode(auths[1])
+			if err != nil {
+				ctx.Handle(401, "no basic auth and digit auth", nil)
+				return
+			}
+
+			authUser, err := models.GetUserByName(authUsername)
+			if err != nil {
+				ctx.Handle(401, "no basic auth and digit auth", nil)
+				return
+			}
+
+			newUser := &models.User{Passwd: passwd}
+			newUser.EncodePasswd()
+			if authUser.Passwd != newUser.Passwd {
+				ctx.Handle(401, "no basic auth and digit auth", nil)
+				return
+			}
+
+			var tp = models.AU_WRITABLE
+			if isPull {
+				tp = models.AU_READABLE
+			}
+
+			has, err := models.HasAccess(authUsername, username+"/"+reponame, tp)
+			if err != nil || !has {
+				ctx.Handle(401, "no basic auth and digit auth", nil)
+				return
+			}
+		} else {
+			authRequired(ctx)
+			return
+		}
+	}
+
 	prefix := path.Join("/", username, params["reponame"])
 	server := webdav.NewServer(
 		models.RepoPath(username, reponame),

+ 6 - 0
templates/status/401.tmpl

@@ -0,0 +1,6 @@
+{{template "base/head" .}}
+{{template "base/navbar" .}}
+<div class="container">
+	401 Unauthorized
+</div>
+{{template "base/footer" .}}

+ 14 - 4
web.go

@@ -169,12 +169,22 @@ func runWeb(*cli.Context) {
 	// Not found handler.
 	m.NotFound(routers.NotFound)
 
+	protocol := base.Cfg.MustValue("server", "PROTOCOL", "http")
 	listenAddr := fmt.Sprintf("%s:%s",
 		base.Cfg.MustValue("server", "HTTP_ADDR"),
 		base.Cfg.MustValue("server", "HTTP_PORT", "3000"))
-	log.Info("Listen: %s", listenAddr)
-	if err := http.ListenAndServe(listenAddr, m); err != nil {
-		fmt.Println(err.Error())
-		//log.Critical(err.Error()) // not working now
+
+	if protocol == "http" {
+		log.Info("Listen: http://%s", listenAddr)
+		if err := http.ListenAndServe(listenAddr, m); err != nil {
+			fmt.Println(err.Error())
+			//log.Critical(err.Error()) // not working now
+		}
+	} else if protocol == "https" {
+		log.Info("Listen: https://%s", listenAddr)
+		if err := http.ListenAndServeTLS(listenAddr, base.Cfg.MustValue("server", "CERT_FILE"),
+			base.Cfg.MustValue("server", "KEY_FILE"), m); err != nil {
+			fmt.Println(err.Error())
+		}
 	}
 }