169 lines
5.0 KiB
Go
169 lines
5.0 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
|
)
|
|
|
|
var (
	// db is the process-wide connection pool. main opens and configures it
	// before the HTTP server starts listening; handleRequest queries it on
	// every incoming request.
	db *sql.DB
)
|
|
|
|
func main() {
|
|
databaseURL := os.Getenv("DATABASE_URL")
|
|
if databaseURL == "" {
|
|
log.Fatal("DATABASE_URL is required")
|
|
}
|
|
|
|
var err error
|
|
db, err = sql.Open("pgx", databaseURL)
|
|
if err != nil {
|
|
log.Fatalf("failed to open database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
db.SetMaxOpenConns(25)
|
|
db.SetMaxIdleConns(5)
|
|
db.SetConnMaxLifetime(time.Minute * 5)
|
|
|
|
if err := db.Ping(); err != nil {
|
|
log.Fatalf("failed to ping database: %v", err)
|
|
}
|
|
|
|
http.HandleFunc("/", handleRequest)
|
|
|
|
port := os.Getenv("PORT")
|
|
if port == "" {
|
|
port = "2020"
|
|
}
|
|
|
|
log.Printf("[resolver] listening on :%s", port)
|
|
if err := http.ListenAndServe(":"+port, nil); err != nil {
|
|
log.Fatalf("[resolver] failed to listen: %v", err)
|
|
}
|
|
}
|
|
|
|
func handleRequest(w http.ResponseWriter, r *http.Request) {
|
|
host := r.Host
|
|
if idx := strings.Index(host, ":"); idx != -1 {
|
|
host = host[:idx]
|
|
}
|
|
|
|
log.Printf("[resolver] host=%s method=%s path=%s", host, r.Method, r.URL.RequestURI())
|
|
|
|
parts := strings.Split(host, ".")
|
|
if len(parts) < 3 {
|
|
http.Error(w, "invalid host", http.StatusBadRequest)
|
|
return
|
|
}
|
|
subdomain := parts[0]
|
|
|
|
var tailscaleIp string
|
|
var port int
|
|
var nodeStatus string
|
|
err := db.QueryRow(`
|
|
SELECT n."tailscaleIp", i.port, n.status
|
|
FROM "Instance" i
|
|
JOIN "Node" n ON i."nodeId" = n.id
|
|
WHERE i.id = $1
|
|
`, subdomain).Scan(&tailscaleIp, &port, &nodeStatus)
|
|
if err == sql.ErrNoRows {
|
|
log.Printf("[resolver] instance not found: %s", subdomain)
|
|
http.Error(w, "instance not found", http.StatusNotFound)
|
|
return
|
|
}
|
|
if err != nil {
|
|
log.Printf("[resolver] database error: %v", err)
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if nodeStatus != "online" || tailscaleIp == "" {
|
|
log.Printf("[resolver] node offline for instance %s (status=%s ip=%s)", subdomain, nodeStatus, tailscaleIp)
|
|
http.Error(w, "node offline", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
upstream := fmt.Sprintf("http://%s:%d", tailscaleIp, port)
|
|
log.Printf("[resolver] proxying %s -> %s", host, upstream)
|
|
|
|
proxyURL, err := url.Parse(upstream)
|
|
if err != nil {
|
|
log.Printf("[resolver] invalid upstream URL %s: %v", upstream, err)
|
|
http.Error(w, "invalid upstream", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
proxy := httputil.NewSingleHostReverseProxy(proxyURL)
|
|
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
|
log.Printf("[resolver] proxy error for %s: %v", upstream, err)
|
|
http.Error(w, "upstream unavailable", http.StatusBadGateway)
|
|
}
|
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
|
// Do not follow 3xx redirects here; rewrite the Location header to HTTPS
|
|
// and let the browser follow it. Caddy will terminate TLS correctly.
|
|
location := resp.Header.Get("Location")
|
|
if location != "" {
|
|
newLocation := strings.ReplaceAll(location, fmt.Sprintf("http://%s", host), fmt.Sprintf("https://%s", host))
|
|
newLocation = strings.ReplaceAll(newLocation, fmt.Sprintf("http://%s:%d", tailscaleIp, port), fmt.Sprintf("https://%s", host))
|
|
if newLocation != location {
|
|
log.Printf("[resolver] rewriting Location (status=%d): %s -> %s", resp.StatusCode, location, newLocation)
|
|
resp.Header.Set("Location", newLocation)
|
|
} else if resp.StatusCode >= 300 && resp.StatusCode < 400 {
|
|
log.Printf("[resolver] passing through %d Location: %s", resp.StatusCode, location)
|
|
}
|
|
} else if resp.StatusCode >= 300 && resp.StatusCode < 400 {
|
|
log.Printf("[resolver] passing through %d without Location header", resp.StatusCode)
|
|
}
|
|
|
|
// Rewrite http:// to https:// in HTML/CSS/JS bodies to avoid mixed content.
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if strings.Contains(contentType, "text/html") ||
|
|
strings.Contains(contentType, "text/css") ||
|
|
strings.Contains(contentType, "application/javascript") {
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp.Body.Close()
|
|
|
|
newBody := strings.ReplaceAll(string(body), fmt.Sprintf("http://%s", host), fmt.Sprintf("https://%s", host))
|
|
newBody = strings.ReplaceAll(newBody, fmt.Sprintf("http://%s:%d", tailscaleIp, port), fmt.Sprintf("https://%s", host))
|
|
|
|
if newBody != string(body) {
|
|
log.Printf("[resolver] rewrote %d bytes in %s body", len(newBody)-len(body), contentType)
|
|
}
|
|
|
|
resp.Body = io.NopCloser(strings.NewReader(newBody))
|
|
resp.ContentLength = int64(len(newBody))
|
|
resp.Header.Set("Content-Length", strconv.Itoa(len(newBody)))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
proxy.Director = func(req *http.Request) {
|
|
req.URL.Scheme = proxyURL.Scheme
|
|
req.URL.Host = proxyURL.Host
|
|
req.URL.Path = r.URL.Path
|
|
req.URL.RawQuery = r.URL.RawQuery
|
|
req.Host = host
|
|
req.Header.Set("X-Forwarded-Host", host)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
req.Header.Set("X-Forwarded-Port", "443")
|
|
if req.Header.Get("X-Forwarded-For") == "" {
|
|
req.Header.Set("X-Forwarded-For", r.RemoteAddr)
|
|
}
|
|
}
|
|
|
|
proxy.ServeHTTP(w, r)
|
|
}
|