diff --git a/fileserve.go b/fileserve.go index 00858f0..c7eac21 100644 --- a/fileserve.go +++ b/fileserve.go @@ -26,7 +26,7 @@ func fileServeHandler(c web.C, w http.ResponseWriter, r *http.Request) { if !Config.allowHotlink { referer := r.Header.Get("Referer") u, _ := url.Parse(referer) - p, _ := url.Parse(Config.siteURL) + p, _ := url.Parse(getSiteURL(r)) if referer != "" && !sameOrigin(u, p) { http.Redirect(w, r, Config.sitePath+fileName, 303) return diff --git a/headers.go b/headers.go index 918a6dd..d0eeda0 100644 --- a/headers.go +++ b/headers.go @@ -2,6 +2,7 @@ package main import ( "net/http" + "net/url" "strings" ) @@ -25,3 +26,26 @@ func AddHeaders(headers []string) func(http.Handler) http.Handler { } return fn } + +func getSiteURL(r *http.Request) string { + if Config.siteURL != "" { + return Config.siteURL + } else { + u := &url.URL{} + u.Host = r.Host + + if Config.sitePath != "" { + u.Path = Config.sitePath + } + + if scheme := r.Header.Get("X-Forwarded-Proto"); scheme != "" { + u.Scheme = scheme + } else if Config.certFile != "" || (r.TLS != nil && r.TLS.HandshakeComplete == true) { + u.Scheme = "https" + } else { + u.Scheme = "http" + } + + return u.String() + } +} diff --git a/pages.go b/pages.go index 9d24bba..597e5a4 100644 --- a/pages.go +++ b/pages.go @@ -36,7 +36,7 @@ func pasteHandler(c web.C, w http.ResponseWriter, r *http.Request) { } func apiDocHandler(c web.C, w http.ResponseWriter, r *http.Request) { - err := Templates["API.html"].ExecuteWriter(pongo2.Context{}, w) + err := Templates["API.html"].ExecuteWriter(pongo2.Context{"siteurl": getSiteURL(r)}, w) if err != nil { oopsHandler(c, w, r, RespHTML, "") } diff --git a/server.go b/server.go index 8c129b6..e7eee43 100644 --- a/server.go +++ b/server.go @@ -102,17 +102,21 @@ func setup() *web.Mux { log.Fatal("Could not create metadata directory:", err) } - // ensure siteURL ends wth '/' - if lastChar := Config.siteURL[len(Config.siteURL)-1:]; lastChar != "/" { - Config.siteURL = Config.siteURL + "/" - } + if Config.siteURL != "" { + // ensure siteURL ends wth '/' + if lastChar := Config.siteURL[len(Config.siteURL)-1:]; lastChar != "/" { + Config.siteURL = Config.siteURL + "/" + } - parsedUrl, err := url.Parse(Config.siteURL) - if err != nil { - log.Fatal("Could not parse siteurl:", err) - } + parsedUrl, err := url.Parse(Config.siteURL) + if err != nil { + log.Fatal("Could not parse siteurl:", err) + } - Config.sitePath = parsedUrl.Path + Config.sitePath = parsedUrl.Path + } else { + Config.sitePath = "/" + } // Template setup p2l, err := NewPongo2TemplatesLoader() @@ -121,7 +125,6 @@ func setup() *web.Mux { } TemplateSet := pongo2.NewSet("templates", p2l) TemplateSet.Globals["sitename"] = Config.siteName - TemplateSet.Globals["siteurl"] = Config.siteURL TemplateSet.Globals["sitepath"] = Config.sitePath TemplateSet.Globals["using_auth"] = Config.authFile != "" err = populateTemplatesMap(TemplateSet, Templates) @@ -193,7 +196,7 @@ func main() { "Allow hotlinking of files") flag.StringVar(&Config.siteName, "sitename", "linx", "name of the site") - flag.StringVar(&Config.siteURL, "siteurl", "http://"+Config.bind+"/", + flag.StringVar(&Config.siteURL, "siteurl", "", "site base url (including trailing slash)") flag.Int64Var(&Config.maxSize, "maxsize", 4*1024*1024*1024, "maximum upload file size in bytes (default 4GB)") @@ -232,16 +235,16 @@ func main() { log.Fatal("Could not bind: ", err) } - log.Printf("Serving over fastcgi, bound on %s, using siteurl %s", Config.bind, Config.siteURL) + log.Printf("Serving over fastcgi, bound on %s", Config.bind) fcgi.Serve(listener, mux) } else if Config.certFile != "" { - log.Printf("Serving over https, bound on %s, using siteurl %s", Config.bind, Config.siteURL) + log.Printf("Serving over https, bound on %s", Config.bind) err := graceful.ListenAndServeTLS(Config.bind, Config.certFile, Config.keyFile, mux) if err != nil { log.Fatal(err) } } else { - log.Printf("Serving over http, bound on %s, using siteurl %s", Config.bind, Config.siteURL) + log.Printf("Serving over http, bound on %s", Config.bind) err := graceful.ListenAndServe(Config.bind, mux) if err != nil { log.Fatal(err) diff --git a/server_test.go b/server_test.go index 5cb1d52..527e3be 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "crypto/tls" "encoding/json" "mime/multipart" "net/http" @@ -923,6 +924,102 @@ func TestExtension(t *testing.T) { } } +func TestInferSiteURL(t *testing.T) { + oldSiteURL := Config.siteURL + oldSitePath := Config.sitePath + Config.siteURL = "" + Config.sitePath = "/linxtest/" + + mux := setup() + w := httptest.NewRecorder() + + req, err := http.NewRequest("GET", "/API/", nil) + req.Host = "example.com:8080" + if err != nil { + t.Fatal(err) + } + + mux.ServeHTTP(w, req) + + if !strings.Contains(w.Body.String(), "http://example.com:8080/upload/") { + t.Fatal("Site URL not found properly embedded in response") + } + + Config.siteURL = oldSiteURL + Config.sitePath = oldSitePath +} + +func TestInferSiteURLProxied(t *testing.T) { + oldSiteURL := Config.siteURL + Config.siteURL = "" + + mux := setup() + w := httptest.NewRecorder() + + req, err := http.NewRequest("GET", "/API/", nil) + req.Header.Add("X-Forwarded-Proto", "https") + req.Host = "example.com:8080" + if err != nil { + t.Fatal(err) + } + + mux.ServeHTTP(w, req) + + if !strings.Contains(w.Body.String(), "https://example.com:8080/upload/") { + t.Fatal("Site URL not found properly embedded in response") + } + + Config.siteURL = oldSiteURL +} + +func TestInferSiteURLHTTPS(t *testing.T) { + oldSiteURL := Config.siteURL + oldCertFile := Config.certFile + Config.siteURL = "" + Config.certFile = "/dev/null" + + mux := setup() + w := httptest.NewRecorder() + + req, err := http.NewRequest("GET", "/API/", nil) + req.Host = "example.com" + if err != nil { + t.Fatal(err) + } + + mux.ServeHTTP(w, req) + + if !strings.Contains(w.Body.String(), "https://example.com/upload/") { + t.Fatal("Site URL not found properly embedded in response") + } + + Config.siteURL = oldSiteURL + Config.certFile = oldCertFile +} + +func TestInferSiteURLHTTPSFastCGI(t *testing.T) { + oldSiteURL := Config.siteURL + Config.siteURL = "" + + mux := setup() + w := httptest.NewRecorder() + + req, err := http.NewRequest("GET", "/API/", nil) + req.Host = "example.com" + req.TLS = &tls.ConnectionState{HandshakeComplete: true} + if err != nil { + t.Fatal(err) + } + + mux.ServeHTTP(w, req) + + if !strings.Contains(w.Body.String(), "https://example.com/upload/") { + t.Fatal("Site URL not found properly embedded in response") + } + + Config.siteURL = oldSiteURL +} + func TestShutdown(t *testing.T) { os.RemoveAll(Config.filesDir) os.RemoveAll(Config.metaDir) diff --git a/torrent.go b/torrent.go index e3c4952..2e82598 100644 --- a/torrent.go +++ b/torrent.go @@ -37,7 +37,7 @@ func hashPiece(piece []byte) []byte { return h.Sum(nil) } -func createTorrent(fileName string, filePath string) ([]byte, error) { +func createTorrent(fileName string, filePath string, r *http.Request) ([]byte, error) { chunk := make([]byte, TORRENT_PIECE_LENGTH) torrent := Torrent{ @@ -46,7 +46,7 @@ func createTorrent(fileName string, filePath string) ([]byte, error) { PieceLength: TORRENT_PIECE_LENGTH, Name: fileName, }, - UrlList: []string{fmt.Sprintf("%sselif/%s", Config.siteURL, fileName)}, + UrlList: []string{fmt.Sprintf("%sselif/%s", getSiteURL(r), fileName)}, } f, err := os.Open(filePath) @@ -89,7 +89,7 @@ func fileTorrentHandler(c web.C, w http.ResponseWriter, r *http.Request) { return } - encoded, err := createTorrent(fileName, filePath) + encoded, err := createTorrent(fileName, filePath, r) if err != nil { oopsHandler(c, w, r, RespHTML, "Could not create torrent.") return diff --git a/torrent_test.go b/torrent_test.go index 9a44eb9..03a51da 100644 --- a/torrent_test.go +++ b/torrent_test.go @@ -11,7 +11,7 @@ func TestCreateTorrent(t *testing.T) { fileName := "server.go" var decoded Torrent - encoded, err := createTorrent(fileName, fileName) + encoded, err := createTorrent(fileName, fileName, nil) if err != nil { t.Fatal(err) } @@ -47,7 +47,7 @@ func TestCreateTorrent(t *testing.T) { func TestCreateTorrentWithImage(t *testing.T) { var decoded Torrent - encoded, err := createTorrent("test.jpg", "static/images/404.jpg") + encoded, err := createTorrent("test.jpg", "static/images/404.jpg", nil) if err != nil { t.Fatal(err) } diff --git a/upload.go b/upload.go index 2fd4ec3..33d845d 100644 --- a/upload.go +++ b/upload.go @@ -46,7 +46,7 @@ type Upload struct { } func uploadPostHandler(c web.C, w http.ResponseWriter, r *http.Request) { - if !strictReferrerCheck(r, Config.siteURL, []string{"Linx-Delete-Key", "Linx-Expiry", "Linx-Randomize", "X-Requested-With"}) { + if !strictReferrerCheck(r, getSiteURL(r), []string{"Linx-Delete-Key", "Linx-Expiry", "Linx-Randomize", "X-Requested-With"}) { badRequestHandler(c, w, r) return } @@ -94,7 +94,7 @@ func uploadPostHandler(c web.C, w http.ResponseWriter, r *http.Request) { return } - js := generateJSONresponse(upload) + js := generateJSONresponse(upload, r) w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Write(js) } else { @@ -124,7 +124,7 @@ func uploadPutHandler(c web.C, w http.ResponseWriter, r *http.Request) { return } - js := generateJSONresponse(upload) + js := generateJSONresponse(upload, r) w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Write(js) } else { @@ -133,7 +133,7 @@ func uploadPutHandler(c web.C, w http.ResponseWriter, r *http.Request) { return } - fmt.Fprintf(w, Config.siteURL+upload.Filename) + fmt.Fprintf(w, getSiteURL(r)+upload.Filename) } } @@ -174,7 +174,7 @@ func uploadRemote(c web.C, w http.ResponseWriter, r *http.Request) { return } - js := generateJSONresponse(upload) + js := generateJSONresponse(upload, r) w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Write(js) } else { @@ -308,9 +308,9 @@ func generateBarename() string { return uniuri.NewLenChars(8, []byte("abcdefghijklmnopqrstuvwxyz0123456789")) } -func generateJSONresponse(upload Upload) []byte { +func generateJSONresponse(upload Upload, r *http.Request) []byte { js, _ := json.Marshal(map[string]string{ - "url": Config.siteURL + upload.Filename, + "url": getSiteURL(r) + upload.Filename, "filename": upload.Filename, "delete_key": upload.Metadata.DeleteKey, "expiry": strconv.FormatInt(upload.Metadata.Expiry.Unix(), 10),