// Command resolver is an HTTP reverse proxy that routes each request by the
// first label of its Host header. That label is looked up as an Instance id
// in Postgres, and the request is forwarded to the instance's port on the
// Tailscale IP of the node that hosts it.
package main

import (
	"database/sql"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"

	_ "github.com/jackc/pgx/v5/stdlib"
)

// readHeaderTimeout bounds how long a client may take to send its request
// headers. Only the header phase is limited so that long-lived proxied
// responses (streams, upgrades) are not cut off by the server.
const readHeaderTimeout = 10 * time.Second

// db is the shared connection pool; it is opened in main before the server
// starts accepting requests.
var db *sql.DB

// main reads DATABASE_URL and PORT from the environment, opens and pings the
// database, registers handleRequest on the default mux and serves HTTP until
// the listener fails.
func main() {
	databaseURL := os.Getenv("DATABASE_URL")
	if databaseURL == "" {
		log.Fatal("DATABASE_URL is required")
	}

	var err error
	db, err = sql.Open("pgx", databaseURL)
	if err != nil {
		log.Fatalf("failed to open database: %v", err)
	}
	defer db.Close()

	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)

	if err := db.Ping(); err != nil {
		log.Fatalf("failed to ping database: %v", err)
	}

	http.HandleFunc("/", handleRequest)

	port := os.Getenv("PORT")
	if port == "" {
		port = "2020"
	}

	// A nil Handler makes the server use http.DefaultServeMux, exactly like
	// http.ListenAndServe(addr, nil), but with a header-read timeout so idle
	// connections cannot be held open indefinitely.
	srv := &http.Server{
		Addr:              ":" + port,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	log.Printf("[resolver] listening on :%s", port)
	if err := srv.ListenAndServe(); err != nil {
		log.Fatalf("[resolver] failed to listen: %v", err)
	}
}

// handleRequest resolves the instance named by the request's Host header and
// proxies the request to it.
//
// It responds with 400 when the host has fewer than three dot-separated
// labels, 404 when no instance matches, 500 on a database or URL error, 503
// when the node is not "online" or has no Tailscale IP, and 502 (via the
// proxy's error handler) when the upstream cannot be reached.
func handleRequest(w http.ResponseWriter, r *http.Request) {
	host, subdomain, ok := splitInstanceHost(r.Host)
	log.Printf("[resolver] host=%s method=%s path=%s", host, r.Method, r.URL.RequestURI())
	if !ok {
		http.Error(w, "invalid host", http.StatusBadRequest)
		return
	}

	var (
		tailscaleIP string
		port        int
		nodeStatus  string
	)
	// Bind the query to the request context so that a client disconnect
	// cancels the lookup instead of leaving it running.
	err := db.QueryRowContext(r.Context(), `
		SELECT n."tailscaleIp", i.port, n.status
		FROM "Instance" i
		JOIN "Node" n ON i."nodeId" = n.id
		WHERE i.id = $1
	`, subdomain).Scan(&tailscaleIP, &port, &nodeStatus)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		log.Printf("[resolver] instance not found: %s", subdomain)
		http.Error(w, "instance not found", http.StatusNotFound)
		return
	case err != nil:
		log.Printf("[resolver] database error: %v", err)
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}

	if nodeStatus != "online" || tailscaleIP == "" {
		log.Printf("[resolver] node offline for instance %s (status=%s ip=%s)", subdomain, nodeStatus, tailscaleIP)
		http.Error(w, "node offline", http.StatusServiceUnavailable)
		return
	}

	// JoinHostPort brackets IPv6 addresses; plain "%s:%d" formatting would
	// produce an unparsable URL for them. IPv4 output is unchanged.
	upstream := fmt.Sprintf("http://%s", net.JoinHostPort(tailscaleIP, strconv.Itoa(port)))
	log.Printf("[resolver] proxying %s -> %s", host, upstream)

	target, err := url.Parse(upstream)
	if err != nil {
		log.Printf("[resolver] invalid upstream URL %s: %v", upstream, err)
		http.Error(w, "invalid upstream", http.StatusInternalServerError)
		return
	}

	newInstanceProxy(host, target).ServeHTTP(w, r)
}

// splitInstanceHost strips any ":port" suffix from hostport and reports the
// bare host together with its first dot-separated label. ok is false when the
// host has fewer than three labels.
func splitInstanceHost(hostport string) (host, id string, ok bool) {
	host, _, _ = strings.Cut(hostport, ":")
	if strings.Count(host, ".") < 2 {
		return host, "", false
	}
	id, _, _ = strings.Cut(host, ".")
	return host, id, true
}

// newInstanceProxy builds a reverse proxy that forwards requests to target
// while presenting host as the public, HTTPS-terminated origin.
func newInstanceProxy(host string, target *url.URL) *httputil.ReverseProxy {
	upstream := target.Scheme + "://" + target.Host

	return &httputil.ReverseProxy{
		Director: func(req *http.Request) {
			// req is already a clone of the inbound request, so its path and
			// query are kept as they are; only the destination changes.
			req.URL.Scheme = target.Scheme
			req.URL.Host = target.Host
			req.Host = host
			req.Header.Set("X-Forwarded-Host", host)
			req.Header.Set("X-Forwarded-Proto", "https")
			req.Header.Set("X-Forwarded-Port", "443")
			// X-Forwarded-For is intentionally not set here: ReverseProxy
			// appends the client IP after the Director returns, so setting
			// it to RemoteAddr would yield a duplicated "ip:port, ip" value.

			// Let the Transport negotiate gzip itself so that it decodes the
			// body transparently; a client-supplied Accept-Encoding would
			// hand rewriteResponse compressed bytes it cannot rewrite.
			req.Header.Del("Accept-Encoding")
		},
		ModifyResponse: func(resp *http.Response) error {
			return rewriteResponse(resp, host, upstream)
		},
		ErrorHandler: func(w http.ResponseWriter, _ *http.Request, err error) {
			log.Printf("[resolver] proxy error for %s: %v", upstream, err)
			http.Error(w, "upstream unavailable", http.StatusBadGateway)
		},
	}
}

// rewriteToHTTPS replaces plain-HTTP references to the public host and to
// the upstream origin in s with "https://<host>".
func rewriteToHTTPS(s, host, upstream string) string {
	public := "https://" + host
	s = strings.ReplaceAll(s, "http://"+host, public)
	return strings.ReplaceAll(s, upstream, public)
}

// shouldRewriteBody reports whether resp carries an uncompressed HTML, CSS or
// JavaScript body that is safe to buffer and rewrite.
func shouldRewriteBody(resp *http.Response) bool {
	// HEAD and 304 responses have no body; rewriting them would replace the
	// upstream Content-Length with 0.
	if resp.StatusCode == http.StatusNotModified {
		return false
	}
	if resp.Request != nil && resp.Request.Method == http.MethodHead {
		return false
	}
	// A still-encoded body is opaque bytes; a textual replace cannot work on
	// it and must not touch it.
	if enc := resp.Header.Get("Content-Encoding"); enc != "" && !strings.EqualFold(enc, "identity") {
		return false
	}

	contentType := resp.Header.Get("Content-Type")
	for _, mediaType := range []string{"text/html", "text/css", "application/javascript", "text/javascript"} {
		if strings.Contains(contentType, mediaType) {
			return true
		}
	}
	return false
}

// rewriteResponse upgrades http:// references to the public host or the
// upstream origin to https://<host>, both in the Location header and in
// HTML/CSS/JS bodies, so that browsers behind the TLS terminator do not hit
// redirects to plain HTTP or mixed content.
//
// Redirects are never followed here; the rewritten Location is returned to
// the client. Rewritable bodies are read fully into memory.
func rewriteResponse(resp *http.Response, host, upstream string) error {
	isRedirect := resp.StatusCode >= 300 && resp.StatusCode < 400

	location := resp.Header.Get("Location")
	switch {
	case location != "":
		newLocation := rewriteToHTTPS(location, host, upstream)
		if newLocation != location {
			log.Printf("[resolver] rewriting Location (status=%d): %s -> %s", resp.StatusCode, location, newLocation)
			resp.Header.Set("Location", newLocation)
		} else if isRedirect {
			log.Printf("[resolver] passing through %d Location: %s", resp.StatusCode, location)
		}
	case isRedirect:
		log.Printf("[resolver] passing through %d without Location header", resp.StatusCode)
	}

	if !shouldRewriteBody(resp) {
		return nil
	}

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		// ReverseProxy closes the body and invokes ErrorHandler when
		// ModifyResponse returns an error.
		return fmt.Errorf("reading upstream body: %w", err)
	}
	resp.Body.Close()

	original := string(raw)
	rewritten := rewriteToHTTPS(original, host, upstream)
	if rewritten != original {
		log.Printf("[resolver] rewrote %d bytes in %s body", len(rewritten)-len(original), resp.Header.Get("Content-Type"))
	}

	resp.Body = io.NopCloser(strings.NewReader(rewritten))
	resp.ContentLength = int64(len(rewritten))
	resp.Header.Set("Content-Length", strconv.Itoa(len(rewritten)))
	return nil
}