diff --git a/app/routes.go b/app/routes.go index 7f5e83c..473b32e 100644 --- a/app/routes.go +++ b/app/routes.go @@ -19,6 +19,7 @@ func (a *App) SetupRoutes() { tmpls := a.getTemplates() modtmpl.Init(setting, tmpls) + handlers.SetDBClient(a.db) // run this before declaring any handler funcs. handlers.InitHandlers(setting) @@ -48,10 +49,10 @@ func (a *App) SetupRoutes() { e.POST("/signup", handlers.SignupPost(a.db)) e.GET("/home", handlers.Home(a.db)) - e.GET("/manage/users", handlers.ManageUsers(a.db)) - e.GET("/manage/users/new", handlers.ManageUsers(a.db)) - e.GET("/manage/users/:id", handlers.ViewUser(a.db)) - e.POST("/manage/users/create", handlers.CreateUser(a.db)) + e.GET("/manage/users", handlers.ManageUsers(), handlers.MiddlewareSession) + e.GET("/manage/users/new", handlers.ManageUsers(), handlers.MiddlewareSession) + e.GET("/manage/users/:id", handlers.ViewUser(), handlers.MiddlewareSession) + e.POST("/manage/users/create", handlers.CreateUser(), handlers.MiddlewareSession) e.GET("/logout", handlers.Logout()) e.POST("/logout", handlers.Logout()) diff --git a/handlers/config.go b/handlers/config.go index 1d1cf8d..622538e 100644 --- a/handlers/config.go +++ b/handlers/config.go @@ -5,17 +5,23 @@ package handlers import ( "git.dotya.ml/mirre-mt/pcmt/app/settings" + "git.dotya.ml/mirre-mt/pcmt/ent" "git.dotya.ml/mirre-mt/pcmt/slogging" "golang.org/x/exp/slog" ) var ( - setting *settings.Settings - appver string - slogger *slogging.Slogger - log slogging.Slogger + setting *settings.Settings + appver string + slogger *slogging.Slogger + log slogging.Slogger + dbclient *ent.Client ) +func SetDBClient(client *ent.Client) { + dbclient = client +} + func InitHandlers(s *settings.Settings) { slogger = slogging.Logger() log = *slogger // have a local copy. diff --git a/handlers/error.go b/handlers/error.go index b171f27..876994a 100644 --- a/handlers/error.go +++ b/handlers/error.go @@ -4,12 +4,18 @@ package handlers import ( + "errors" "fmt" "strconv" "github.com/labstack/echo/v4" ) +var ( + ErrNoSession = errors.New("No session found, please log in") + ErrSessionExpired = errors.New("Session expired, log in again") +) + func renderErrorPage(c echo.Context, status int, statusText, error string) error { addHeaders(c) diff --git a/handlers/manage-user.go b/handlers/manage-user.go index eb82508..46e3fd1 100644 --- a/handlers/manage-user.go +++ b/handlers/manage-user.go @@ -11,63 +11,39 @@ import ( "git.dotya.ml/mirre-mt/pcmt/ent" moduser "git.dotya.ml/mirre-mt/pcmt/modules/user" - "github.com/labstack/echo-contrib/session" + "github.com/gorilla/sessions" "github.com/labstack/echo/v4" ) -func ManageUsers(client *ent.Client) echo.HandlerFunc { //nolint:gocognit +func ManageUsers() echo.HandlerFunc { return func(c echo.Context) error { addHeaders(c) - sess, _ := session.Get(setting.SessionCookieName(), c) - if sess == nil { - c.Logger().Info("No session found, unauthorised.") - + u, ok := c.Get("sessUsr").(moduser.User) + if !ok { return renderErrorPage( c, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), - "No session found, please log in", + "it appears there is no user", ) } - uname := sess.Values["username"] - if uname == nil { - c.Logger().Debugf("%d - %s", http.StatusUnauthorized, "session expired or invalid") - - return renderErrorPage( - c, - http.StatusUnauthorized, - http.StatusText(http.StatusUnauthorized)+": Log in again", - "Session expired, log in again.", - ) - } - - log.Info("gorilla session", "username", sess.Values["username"].(string)) - - username := sess.Values["username"].(string) - - var u moduser.User - - ctx := context.WithValue(context.Background(), moduser.CtxKey{}, slogger) - if usr, err := moduser.QueryUser(ctx, client, username); err == nil && usr != nil { - u.ID = usr.ID - u.Username = usr.Username - u.IsAdmin = usr.IsAdmin - u.CreatedAt = usr.CreatedAt - u.IsActive = usr.IsActive - u.IsLoggedIn = true - } else { - c.Logger().Error(http.StatusText(http.StatusInternalServerError) + " - " + err.Error()) - + sess, ok := c.Get("sess").(*sessions.Session) + if !ok { return renderErrorPage( c, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError), - err.Error(), + "missing the session", ) } + ctx, ok := c.Get("sloggerCtx").(context.Context) + if !ok { + ctx = context.WithValue(context.Background(), moduser.CtxKey{}, slogger) + } + if !u.IsAdmin { c.Logger().Debug("this is a restricted endpoint") @@ -121,7 +97,7 @@ func ManageUsers(client *ent.Client) echo.HandlerFunc { //nolint:gocognit var allUsers []*moduser.User - if users, err := moduser.ListAll(ctx, client); err == nil && users != nil { + if users, err := moduser.ListAll(ctx, dbclient); err == nil && users != nil { for _, u := range users { usr := &moduser.User{ Username: u.Username, @@ -185,57 +161,23 @@ func ManageUsers(client *ent.Client) echo.HandlerFunc { //nolint:gocognit } } -func CreateUser(client *ent.Client) echo.HandlerFunc { //nolint:gocognit +func CreateUser() echo.HandlerFunc { //nolint:gocognit return func(c echo.Context) error { addHeaders(c) - sess, _ := session.Get(setting.SessionCookieName(), c) - if sess == nil { - c.Logger().Info("No session found, unauthorised.") - + u, ok := c.Get("sessUsr").(moduser.User) + if !ok { return renderErrorPage( c, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), - "No session found, please log in", + "username was nil", ) } - uname := sess.Values["username"] - if uname == nil { - c.Logger().Debugf("%d - %s", http.StatusUnauthorized, "session expired or invalid") - - return renderErrorPage( - c, - http.StatusUnauthorized, - http.StatusText(http.StatusUnauthorized)+": Log in again", - "Session expired, log in again.", - ) - } - - log.Info("gorilla session", "username", sess.Values["username"].(string)) - - username := sess.Values["username"].(string) - - var u moduser.User - - ctx := context.WithValue(context.Background(), moduser.CtxKey{}, slogger) - if usr, err := moduser.QueryUser(ctx, client, username); err == nil && usr != nil { - u.ID = usr.ID - u.Username = usr.Username - u.IsAdmin = usr.IsAdmin - u.CreatedAt = usr.CreatedAt - u.IsActive = usr.IsActive - u.IsLoggedIn = true - } else { - c.Logger().Error(http.StatusText(http.StatusInternalServerError) + " - " + err.Error()) - - return renderErrorPage( - c, - http.StatusInternalServerError, - http.StatusText(http.StatusInternalServerError), - err.Error(), - ) + ctx, ok := c.Get("sloggerCtx").(context.Context) + if !ok { + ctx = context.WithValue(context.Background(), moduser.CtxKey{}, slogger) } if !u.IsAdmin { @@ -249,6 +191,14 @@ func CreateUser(client *ent.Client) echo.HandlerFunc { //nolint:gocognit ) } + p := page{ + AppName: setting.AppName(), + AppVer: appver, + Title: "Manage Users", + DevelMode: setting.IsDevel(), + Current: "manage-user", + } + data := make(map[string]any) uc := new(userCreate) if err := c.Bind(uc); err != nil { @@ -269,33 +219,27 @@ func CreateUser(client *ent.Client) echo.HandlerFunc { //nolint:gocognit msg += "; password needs to be passed the same twice" } - data := make(map[string]any) - data["flash"] = msg data["form"] = uc + p.Data = data return c.Render( http.StatusBadRequest, "manage/user-new.tmpl", - page{ - AppName: setting.AppName(), - AppVer: appver, - Title: "Manage Users - New User", - DevelMode: setting.IsDevel(), - Current: "manage-user-new", - Data: data, - }, + p, ) } var msg string - usr, err := moduser.CreateUser(ctx, client, uc.Email, uc.Username, uc.Password, uc.IsAdmin) + usr, err := moduser.CreateUser(ctx, dbclient, uc.Email, uc.Username, uc.Password, uc.IsAdmin) if err == nil && usr != nil { msg = "created user '" + usr.Username + "'!" - sess.Values["flash"] = msg - _ = sess.Save(c.Request(), c.Response()) + if sess, ok := c.Get("sess").(*sessions.Session); ok { + sess.Values["flash"] = msg + _ = sess.Save(c.Request(), c.Response()) + } return c.Redirect(http.StatusSeeOther, "/manage/users") } @@ -308,98 +252,35 @@ func CreateUser(client *ent.Client) echo.HandlerFunc { //nolint:gocognit msg = "Error: " + err.Error() } - data := make(map[string]any) - data["flash"] = msg data["form"] = uc + p.Data = data return c.Render( http.StatusInternalServerError, "manage/user-new.tmpl", - page{ - AppName: setting.AppName(), - AppVer: appver, - Title: "Manage Users", - DevelMode: setting.IsDevel(), - Current: "manage-user", - Data: data, - }, + p, ) } } -func ViewUser(client *ent.Client) echo.HandlerFunc { +func ViewUser() echo.HandlerFunc { return func(c echo.Context) error { addHeaders(c) - sess, _ := session.Get(setting.SessionCookieName(), c) - if sess == nil { - c.Logger().Info("No session found, unauthorised.") - + u, ok := c.Get("sessUsr").(moduser.User) + if !ok { return renderErrorPage( c, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), - "No session found, please log in", + "username was nil", ) } - uname := sess.Values["username"] - if uname == nil { - c.Logger().Debugf("%d - %s", http.StatusUnauthorized, "session expired or invalid") - - return renderErrorPage( - c, - http.StatusUnauthorized, - http.StatusText(http.StatusUnauthorized)+": Log in again", - "Session expired, log in again.", - ) - } - - log.Info("gorilla session", "username", sess.Values["username"].(string)) - - username := sess.Values["username"].(string) - - var u moduser.User - - ctx := context.WithValue(context.Background(), moduser.CtxKey{}, slogger) - if usr, err := moduser.QueryUser(ctx, client, username); err == nil && usr != nil { - u.ID = usr.ID - u.Username = usr.Username - u.IsAdmin = usr.IsAdmin - u.CreatedAt = usr.CreatedAt - u.IsActive = usr.IsActive - u.IsLoggedIn = true - } else { - c.Logger().Error(http.StatusText(http.StatusInternalServerError) + " - " + err.Error()) - - return renderErrorPage( - c, - http.StatusInternalServerError, - http.StatusText(http.StatusInternalServerError), - err.Error(), - ) - } - - refreshSession( - sess, - "/", - // setting.SessionMaxAge, - 86400, - true, - c.Request().URL.Scheme == "https", //nolint:goconst - http.SameSiteStrictMode, - ) - - if err := sess.Save(c.Request(), c.Response()); err != nil { - c.Logger().Error("failed to save session") - - return renderErrorPage( - c, - http.StatusInternalServerError, - http.StatusText(http.StatusInternalServerError)+" (make sure you've got cookies enabled)", - err.Error(), - ) + ctx, ok := c.Get("sloggerCtx").(context.Context) + if !ok { + ctx = context.WithValue(context.Background(), moduser.CtxKey{}, slogger) } if !u.IsAdmin { @@ -426,7 +307,7 @@ func ViewUser(client *ent.Client) echo.HandlerFunc { err := c.Bind(uid) if err == nil { - usr, err := getUserByID(ctx, client, uid.ID) + usr, err := getUserByID(ctx, dbclient, uid.ID) if err != nil { if errors.Is(err, moduser.ErrUserNotFound) { //nolint:gocritic c.Logger().Errorf("user not found by ID: '%s'", uid.ID) diff --git a/handlers/middleware.go b/handlers/middleware.go new file mode 100644 index 0000000..11a1269 --- /dev/null +++ b/handlers/middleware.go @@ -0,0 +1,99 @@ +// Copyright 2023 wanderer +// SPDX-License-Identifier: AGPL-3.0-only + +package handlers + +import ( + "context" + "net/http" + + moduser "git.dotya.ml/mirre-mt/pcmt/modules/user" + "github.com/labstack/echo-contrib/session" + "github.com/labstack/echo/v4" +) + +func MiddlewareSession(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + sess, _ := session.Get(setting.SessionCookieName(), c) + if sess == nil { + c.Logger().Info("No session found, unauthorised.") + + // return a 404 instead of 401 to not disclose the existence of + // resources for unauthenticated users with no past sessions. + return echo.NewHTTPError(http.StatusNotFound).SetInternal(ErrNoSession) + } + + uname := sess.Values["username"] + if uname == nil { + c.Logger().Debugf("%d - %s", http.StatusUnauthorized, "seassion expired or invalid") + // return echo.NewHTTPError(http.StatusUnauthorized).SetInternal(ErrSessionExpired) + return renderErrorPage( + c, + http.StatusUnauthorized, + http.StatusText(http.StatusUnauthorized), + ErrSessionExpired.Error(), + ) + } + + username, ok := sess.Values["username"].(string) + if !ok { + return renderErrorPage( + c, + http.StatusUnauthorized, + http.StatusText(http.StatusUnauthorized), + "username was nil", + ) + } + + log.Info("gorilla session", "username", username) + + refreshSession( + sess, + "/", + // setting.SessionMaxAge, + 86400, + true, + c.Request().URL.Scheme == "https", //nolint:goconst + http.SameSiteStrictMode, + ) + + if err := sess.Save(c.Request(), c.Response()); err != nil { + c.Logger().Error("failed to save session") + + return renderErrorPage( + c, + http.StatusInternalServerError, + http.StatusText(http.StatusInternalServerError)+" (make sure you've got cookies enabled)", + err.Error(), + ) + } + + c.Set("sess", sess) + + var u moduser.User + + ctx := context.WithValue(context.Background(), moduser.CtxKey{}, slogger) + if usr, err := moduser.QueryUser(ctx, dbclient, username); err == nil && usr != nil { + u.ID = usr.ID + u.Username = usr.Username + u.IsAdmin = usr.IsAdmin + u.CreatedAt = usr.CreatedAt + u.IsActive = usr.IsActive + u.IsLoggedIn = true + } else { + c.Logger().Error(http.StatusText(http.StatusInternalServerError) + " - " + err.Error()) + + return renderErrorPage( + c, + http.StatusInternalServerError, + http.StatusText(http.StatusInternalServerError), + err.Error(), + ) + } + + c.Set("sloggerCtx", ctx) + c.Set("sessUsr", u) + + return next(c) + } +}