594 lines
14 KiB
Go
594 lines
14 KiB
Go
// Copyright 2023 wanderer <a_mirre at utb dot cz>
|
|
// SPDX-License-Identifier: AGPL-3.0-only
|
|
|
|
package user
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"git.dotya.ml/mirre-mt/pcmt/ent"
|
|
"git.dotya.ml/mirre-mt/pcmt/ent/user"
|
|
passwd "git.dotya.ml/mirre-mt/pcmt/modules/password"
|
|
"git.dotya.ml/mirre-mt/pcmt/slogging"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/exp/slog"
|
|
)
|
|
|
|
type User struct {
|
|
ID uuid.UUID
|
|
Username string
|
|
Email string
|
|
IsActive bool
|
|
IsAdmin bool
|
|
CreatedAt time.Time
|
|
UpdatedAt time.Time
|
|
IsLoggedIn bool
|
|
LastLogin time.Time
|
|
}
|
|
|
|
// CreateUser adds a user entry to the database.
|
|
func CreateUser(ctx context.Context, client *ent.Client, email, username, password string, isAdmin ...bool) (*ent.User, error) {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
digest, err := passwd.GetHash(password)
|
|
if err != nil {
|
|
log.Errorf("error hashing password: %s", err)
|
|
return nil, errors.New("could not hash password")
|
|
}
|
|
|
|
var admin bool
|
|
|
|
// if set, the first of the array is the arg.
|
|
if len(isAdmin) != 0 {
|
|
admin = isAdmin[0]
|
|
}
|
|
|
|
u, err := client.User.
|
|
Create().
|
|
SetEmail(email).
|
|
SetUsername(username).
|
|
SetPassword(digest).
|
|
SetIsAdmin(admin).
|
|
Save(ctx)
|
|
|
|
switch {
|
|
case ent.IsConstraintError(err):
|
|
log.Errorf("the username '%s' already exists", username)
|
|
return nil, errors.New("username is not unique")
|
|
|
|
case err != nil:
|
|
return nil, fmt.Errorf("failed creating user: %w", err)
|
|
}
|
|
|
|
log.Debug(
|
|
fmt.Sprintf(
|
|
"user successfully created - id: %s, name: %s, isActive: %t, isAdmin: %t, createdAt: %s, updatedAt: %s",
|
|
u.ID, u.Username, u.IsActive, u.IsAdmin, u.CreatedAt, u.UpdatedAt,
|
|
),
|
|
)
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func QueryUser(ctx context.Context, client *ent.Client, username string) (*ent.User, error) {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
u, err := client.User.
|
|
Query().
|
|
Where(user.Username(username)).
|
|
// `Only` fails if no user found,
|
|
// or more than 1 user returned.
|
|
Only(ctx)
|
|
|
|
switch {
|
|
case ent.IsNotFound(err):
|
|
return nil, fmt.Errorf("user not found: %q", err)
|
|
|
|
case err != nil:
|
|
log.Warn("error querying user", "error", err, "username requested", username)
|
|
return nil, fmt.Errorf("failed querying user: %w", err)
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
// QueryUserByID returns user for the provided ID, and nil if err == nil, nil
|
|
// and err otherwise.
|
|
func QueryUserByID(ctx context.Context, client *ent.Client, strID string) (*ent.User, error) {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
id, err := uuid.Parse(strID)
|
|
if err != nil {
|
|
return nil, ErrBadUUID
|
|
}
|
|
|
|
return QueryUserByUUID(ctx, client, id)
|
|
}
|
|
|
|
// QueryUserByUUID returns user for the provided ID, and nil if err == nil, nil
|
|
// and err otherwise.
|
|
func QueryUserByUUID(ctx context.Context, client *ent.Client, id uuid.UUID) (*ent.User, error) {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
u, err := client.User.
|
|
Query().
|
|
Where(user.IDEQ(id)).
|
|
// `Only` fails if no user found,
|
|
// or more than 1 user returned.
|
|
Only(ctx)
|
|
|
|
switch {
|
|
case ent.IsNotFound(err):
|
|
log.Warnf("user not found by ID: %q", err)
|
|
return nil, ErrUserNotFound
|
|
|
|
case err != nil:
|
|
log.Warn("failed to query user by ID", "error", err, "uuid requested", id)
|
|
return nil, ErrFailedToQueryUser
|
|
}
|
|
|
|
return u, nil
|
|
}
|
|
|
|
func UsrFinishedSetup(ctx context.Context, client *ent.Client, id uuid.UUID) (bool, error) {
|
|
u, err := client.User.Get(ctx, id)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
if u.LastLogin.After(time.Unix(0, 0)) {
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func ChangePassFirstLogin(ctx context.Context, client *ent.Client, id uuid.UUID, password string) error {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
finishedSetup, err := UsrFinishedSetup(ctx, client, id)
|
|
|
|
switch {
|
|
case err != nil:
|
|
return err
|
|
|
|
case finishedSetup:
|
|
return nil
|
|
}
|
|
|
|
if password == "" {
|
|
return ErrPasswordEmpty
|
|
}
|
|
|
|
{
|
|
u, err := QueryUserByUUID(ctx, client, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
equal := passwd.Compare(u.Password, password)
|
|
if equal {
|
|
return ErrNewPasswordCannotEqual
|
|
}
|
|
}
|
|
|
|
var digest []byte
|
|
|
|
digest, err = passwd.GetHash(password)
|
|
if err != nil {
|
|
log.Errorf("error hashing password: %s", err)
|
|
return errors.New("could not hash password")
|
|
}
|
|
|
|
// TODO: turn this into a transaction.
|
|
u, err := client.User.
|
|
Update().Where(user.IDEQ(id)).
|
|
SetPassword(digest).
|
|
Save(ctx)
|
|
|
|
switch {
|
|
case err != nil:
|
|
return fmt.Errorf("failed to update user password: %w", err)
|
|
|
|
case u > 1:
|
|
return fmt.Errorf("somehow updated password of more than one user? count: %d", u)
|
|
}
|
|
|
|
u, err = client.User.
|
|
Update().Where(user.IDEQ(id)).
|
|
SetLastLogin(time.Now()).
|
|
Save(ctx)
|
|
|
|
switch {
|
|
case err != nil:
|
|
return fmt.Errorf("failed to set last login: %w", err)
|
|
|
|
case u > 1:
|
|
return fmt.Errorf("somehow updated last login information of more than one user? count: %d", u)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func UpdateUserByAdmin(ctx context.Context, client *ent.Client, id uuid.UUID, email, username, password string, isAdmin bool, isActive *bool) error {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
finishedSetup, err := UsrFinishedSetup(ctx, client, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var active bool
|
|
|
|
if isActive != nil {
|
|
active = *isActive
|
|
}
|
|
|
|
var u int
|
|
|
|
switch {
|
|
// ignore updates to password when user finished setting up (if not admin).
|
|
case !isAdmin && finishedSetup:
|
|
u, err = client.User.
|
|
Update().Where(user.IDEQ(id)).
|
|
SetEmail(email).
|
|
SetUsername(username).
|
|
SetIsAdmin(isAdmin).
|
|
SetIsActive(active).
|
|
Save(ctx)
|
|
|
|
default:
|
|
var digest []byte
|
|
|
|
if digest, err = passwd.GetHash(password); err != nil {
|
|
log.Errorf("error hashing password: %s", err)
|
|
return errors.New("could not hash password")
|
|
}
|
|
|
|
var origU *ent.User
|
|
|
|
if origU, err = QueryUserByUUID(ctx, client, id); err != nil {
|
|
return err
|
|
}
|
|
|
|
// handle a situation when an admin account is demoted to a
|
|
// regular-user level. reset last-login so as to force the user to go
|
|
// through the initial password change flow.
|
|
if origU.IsAdmin && !isAdmin {
|
|
u, err = client.User.
|
|
Update().Where(user.IDEQ(id)).
|
|
SetEmail(email).
|
|
SetUsername(username).
|
|
SetPassword(digest).
|
|
SetIsAdmin(isAdmin).
|
|
SetIsActive(active).
|
|
SetLastLogin(time.Unix(0, 0)).
|
|
Save(ctx)
|
|
} else {
|
|
u, err = client.User.
|
|
Update().Where(user.IDEQ(id)).
|
|
SetEmail(email).
|
|
SetUsername(username).
|
|
SetPassword(digest).
|
|
SetIsAdmin(isAdmin).
|
|
SetIsActive(active).
|
|
Save(ctx)
|
|
}
|
|
}
|
|
|
|
switch {
|
|
case ent.IsConstraintError(err):
|
|
log.Errorf("the username '%s' already exists", username)
|
|
return errors.New("username is not unique")
|
|
|
|
case err != nil:
|
|
return fmt.Errorf("failed to update user: %w", err)
|
|
|
|
case u > 1:
|
|
return fmt.Errorf("somehow updated more than one user? count: %d", u)
|
|
}
|
|
|
|
log.Debug(
|
|
fmt.Sprintf(
|
|
"user successfully updated - id: %s, name: %s, active: %t, admin: %t",
|
|
id, username, *isActive, isAdmin,
|
|
),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateUserLastLogin serves to update the last_login param of the user. This
|
|
// parameter will not get updated for users that never finished setting up,
|
|
// return nil on success and error on err.
|
|
func UpdateUserLastLogin(ctx context.Context, client *ent.Client, id uuid.UUID, isAdmin bool) error {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
finishedSetup, err := UsrFinishedSetup(ctx, client, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !isAdmin && !finishedSetup {
|
|
return ErrUnfinishedSetupLastLoginUpdate
|
|
}
|
|
|
|
u, err := client.User.
|
|
Update().Where(user.IDEQ(id)).
|
|
SetLastLogin(time.Now()).
|
|
Save(ctx)
|
|
|
|
switch {
|
|
case err != nil:
|
|
return fmt.Errorf("failed to update last_login for user: %w", err)
|
|
|
|
case u > 1:
|
|
return fmt.Errorf("somehow updated last_login for more than one user? count: %d", u)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteUserByID returns nil on successful deletion, err otherwise.
|
|
func DeleteUserByID(ctx context.Context, client *ent.Client, strID string) error {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
id, err := uuid.Parse(strID)
|
|
if err != nil {
|
|
return ErrBadUUID
|
|
}
|
|
|
|
err = client.User.
|
|
DeleteOneID(id).Exec(ctx)
|
|
|
|
switch {
|
|
case ent.IsNotFound(err):
|
|
log.Warnf("user not found by ID: %q", err)
|
|
return ErrUserNotFound
|
|
|
|
case err != nil:
|
|
log.Warn("failed to query user by ID", "error", err, "uuid requested", id)
|
|
return ErrFailedToQueryUser
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Exists determines whether the username OR email in question was previously
|
|
// used to register a user.
|
|
func Exists(ctx context.Context, client *ent.Client, username, email string) (bool, error) {
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
usernameExists, err := UsernameExists(ctx, client, username)
|
|
if err != nil {
|
|
log.Warn("failed to check whether username is taken", "error", err, "username requested", username)
|
|
|
|
return false, ErrFailedToQueryUser
|
|
}
|
|
|
|
emailExists, err := EmailExists(ctx, client, email)
|
|
if err != nil {
|
|
log.Warn("failed to check whether user email exists", "error", err, "user email requested", email)
|
|
|
|
return false, ErrFailedToQueryUser
|
|
}
|
|
|
|
switch {
|
|
case usernameExists && emailExists:
|
|
log.Infof("user exists: both username '%s' and email: '%s' matched", username, email)
|
|
return true, nil
|
|
|
|
case usernameExists:
|
|
log.Infof("username '%s' is already taken", username)
|
|
return true, nil
|
|
|
|
case emailExists:
|
|
log.Infof("email '%s' is already registered", email)
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// UsernameExists queries the database to check whether the username is
|
|
// available or taken, returns a bool and an error, which will be nil unless
|
|
// the error is one of IsNot{Found,Singular}.
|
|
func UsernameExists(ctx context.Context, client *ent.Client, username string) (bool, error) { //nolint:dupl
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
usr, err := client.User.
|
|
Query().
|
|
Where(user.Username(username)).
|
|
Only(ctx)
|
|
|
|
switch {
|
|
case ent.IsNotFound(err):
|
|
log.Infof("username '%s' not found", username)
|
|
return false, nil
|
|
|
|
case ent.IsNotSingular(err):
|
|
log.Errorf("apparently more than one user managed to acquire the username '%s', bailing", username)
|
|
return true, err
|
|
|
|
case err != nil:
|
|
log.Warn("failed to check whether user exists", "error", err, "username queried", username)
|
|
return false, fmt.Errorf("failed querying username: %w", err)
|
|
}
|
|
|
|
if usr != nil {
|
|
log.Infof("username '%s' found, user id: %s", username, usr.ID)
|
|
return true, nil
|
|
}
|
|
|
|
log.Warn("we should not have gotten here, apparently error was nil but so was usr...")
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// EmailExists queries the database to check whether the email was already
|
|
// used; returns a bool and an error, which will be nil unless the error is not
|
|
// one of IsNot{Found,Singular}.
|
|
func EmailExists(ctx context.Context, client *ent.Client, email string) (bool, error) { //nolint:dupl
|
|
slogger := ctx.Value(CtxKey{}).(*slogging.Slogger)
|
|
log := *slogger
|
|
|
|
log.Logger = log.Logger.With(
|
|
slog.Group("pcmt extra", slog.String("module", "modules/user")),
|
|
)
|
|
|
|
usr, err := client.User.
|
|
Query().
|
|
Where(user.Email(email)).
|
|
Only(ctx)
|
|
|
|
switch {
|
|
case ent.IsNotFound(err):
|
|
log.Infof("user email '%s' not found", email)
|
|
return false, nil
|
|
|
|
case ent.IsNotSingular(err):
|
|
log.Errorf("apparently more than one user managed to register using the email '%s', bailing", email)
|
|
return true, err
|
|
|
|
case err != nil:
|
|
log.Warn("failed to check whether email exists", "error", err, "email queried", email)
|
|
return false, fmt.Errorf("failed querying user email: %w", err)
|
|
}
|
|
|
|
if usr != nil {
|
|
log.Infof("user email '%s' found, user id: %s", email, usr.ID)
|
|
return true, nil
|
|
}
|
|
|
|
log.Warn("we should not have gotten here, apparently error was nil but so was usr...")
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func ListAll(ctx context.Context, client *ent.Client) ([]*ent.User, error) {
|
|
users, err := client.User.
|
|
Query().All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
func ListAllRegular(ctx context.Context, client *ent.Client) ([]*ent.User, error) {
|
|
users, err := client.User.
|
|
Query().
|
|
Where(user.IsAdminEQ(false)).
|
|
All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
func ListAllAdmins(ctx context.Context, client *ent.Client) ([]*ent.User, error) {
|
|
admins, err := client.User.
|
|
Query().
|
|
Where(user.IsAdminEQ(true)).
|
|
All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return admins, nil
|
|
}
|
|
|
|
// NoUsers checks whether there are any users at all in the db.
|
|
func NoUsers(ctx context.Context, client *ent.Client) (bool, error) {
|
|
count, err := client.User.
|
|
Query().
|
|
Count(ctx)
|
|
if err != nil {
|
|
return false, nil
|
|
}
|
|
|
|
if count > 0 {
|
|
return false, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// CreateFirst creates the first user and makes them an administrator.
|
|
// To be used during app setup.
|
|
func CreateFirst(ctx context.Context, client *ent.Client, username, email, password string) error {
|
|
noUsers, err := NoUsers(ctx, client)
|
|
|
|
switch {
|
|
case err != nil:
|
|
return err
|
|
|
|
case noUsers:
|
|
_, err := CreateUser(ctx, client, email, username, password, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
|
|
case !noUsers:
|
|
return ErrUsersAlreadyPresent
|
|
}
|
|
|
|
return err
|
|
}
|