go: add user email handling + improve sessions
All checks were successful
continuous-integration/drone/push Build is passing

* add Email field to User entity (+codegen)
* switch to Echo contrib sessions backed by Gorilla sessions
* use SessionCookieSecret value from the config
* stage mod,sum changes
* add clearer signup/signin redirect logic
* render error pages on 500/404s and only fall back to returning raw
  errors when the error is some unexpected kind
* add username/email "exists" funcs+tests - handle "not found" and "not
  unique" errors, return errors otherwise
This commit is contained in:
leo 2023-05-01 22:48:11 +02:00
parent 61ec8bfea1
commit 593454d616
Signed by: wanderer
SSH Key Fingerprint: SHA256:Dp8+iwKHSlrMEHzE3bJnPng70I7LEsa3IJXRH/U+idQ
15 changed files with 655 additions and 198 deletions

@ -3,6 +3,8 @@ package app
import (
"net/http"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4/middleware"
)
@ -37,6 +39,8 @@ func (a *App) SetEchoSettings() {
e.Use(middleware.Recover())
e.Use(session.Middleware(sessions.NewCookieStore([]byte(a.config.SessionCookieSecret))))
// e.Use(middleware.CSRF())
e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
TokenLookup: "cookie:_csrf",

@ -37,6 +37,7 @@ var (
UsersColumns = []*schema.Column{
{Name: "id", Type: field.TypeUUID, Unique: true},
{Name: "username", Type: field.TypeString, Unique: true},
{Name: "email", Type: field.TypeString, Unique: true},
{Name: "password", Type: field.TypeString},
{Name: "is_admin", Type: field.TypeBool, Default: false},
{Name: "is_active", Type: field.TypeBool, Default: true},

@ -1177,6 +1177,7 @@ type UserMutation struct {
typ string
id *uuid.UUID
username *string
email *string
password *string
is_admin *bool
is_active *bool
@ -1328,6 +1329,42 @@ func (m *UserMutation) ResetUsername() {
m.username = nil
}
// SetEmail sets the "email" field.
func (m *UserMutation) SetEmail(s string) {
m.email = &s
}
// Email returns the value of the "email" field in the mutation.
func (m *UserMutation) Email() (r string, exists bool) {
v := m.email
if v == nil {
return
}
return *v, true
}
// OldEmail returns the old "email" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldEmail(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldEmail is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldEmail requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldEmail: %w", err)
}
return oldValue.Email, nil
}
// ResetEmail resets all changes to the "email" field.
func (m *UserMutation) ResetEmail() {
m.email = nil
}
// SetPassword sets the "password" field.
func (m *UserMutation) SetPassword(s string) {
m.password = &s
@ -1542,10 +1579,13 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 6)
fields := make([]string, 0, 7)
if m.username != nil {
fields = append(fields, user.FieldUsername)
}
if m.email != nil {
fields = append(fields, user.FieldEmail)
}
if m.password != nil {
fields = append(fields, user.FieldPassword)
}
@ -1571,6 +1611,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
switch name {
case user.FieldUsername:
return m.Username()
case user.FieldEmail:
return m.Email()
case user.FieldPassword:
return m.Password()
case user.FieldIsAdmin:
@ -1592,6 +1634,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
switch name {
case user.FieldUsername:
return m.OldUsername(ctx)
case user.FieldEmail:
return m.OldEmail(ctx)
case user.FieldPassword:
return m.OldPassword(ctx)
case user.FieldIsAdmin:
@ -1618,6 +1662,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetUsername(v)
return nil
case user.FieldEmail:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetEmail(v)
return nil
case user.FieldPassword:
v, ok := value.(string)
if !ok {
@ -1705,6 +1756,9 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldUsername:
m.ResetUsername()
return nil
case user.FieldEmail:
m.ResetEmail()
return nil
case user.FieldPassword:
m.ResetPassword()
return nil

@ -47,24 +47,28 @@ func init() {
userDescUsername := userFields[1].Descriptor()
// user.UsernameValidator is a validator for the "username" field. It is called by the builders before save.
user.UsernameValidator = userDescUsername.Validators[0].(func(string) error)
// userDescEmail is the schema descriptor for email field.
userDescEmail := userFields[2].Descriptor()
// user.EmailValidator is a validator for the "email" field. It is called by the builders before save.
user.EmailValidator = userDescEmail.Validators[0].(func(string) error)
// userDescPassword is the schema descriptor for password field.
userDescPassword := userFields[2].Descriptor()
userDescPassword := userFields[3].Descriptor()
// user.PasswordValidator is a validator for the "password" field. It is called by the builders before save.
user.PasswordValidator = userDescPassword.Validators[0].(func(string) error)
// userDescIsAdmin is the schema descriptor for is_admin field.
userDescIsAdmin := userFields[3].Descriptor()
userDescIsAdmin := userFields[4].Descriptor()
// user.DefaultIsAdmin holds the default value on creation for the is_admin field.
user.DefaultIsAdmin = userDescIsAdmin.Default.(bool)
// userDescIsActive is the schema descriptor for is_active field.
userDescIsActive := userFields[4].Descriptor()
userDescIsActive := userFields[5].Descriptor()
// user.DefaultIsActive holds the default value on creation for the is_active field.
user.DefaultIsActive = userDescIsActive.Default.(bool)
// userDescCreatedAt is the schema descriptor for created_at field.
userDescCreatedAt := userFields[5].Descriptor()
userDescCreatedAt := userFields[6].Descriptor()
// user.DefaultCreatedAt holds the default value on creation for the created_at field.
user.DefaultCreatedAt = userDescCreatedAt.Default.(func() time.Time)
// userDescUpdatedAt is the schema descriptor for updated_at field.
userDescUpdatedAt := userFields[6].Descriptor()
userDescUpdatedAt := userFields[7].Descriptor()
// user.DefaultUpdatedAt holds the default value on creation for the updated_at field.
user.DefaultUpdatedAt = userDescUpdatedAt.Default.(func() time.Time)
// user.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.

@ -23,6 +23,9 @@ func (User) Fields() []ent.Field {
field.String("username").
NotEmpty().
Unique(),
field.String("email").
NotEmpty().
Unique(),
field.String("password").
Sensitive().
NotEmpty(),

@ -19,6 +19,8 @@ type User struct {
ID uuid.UUID `json:"id,omitempty"`
// Username holds the value of the "username" field.
Username string `json:"username,omitempty"`
// Email holds the value of the "email" field.
Email string `json:"email,omitempty"`
// Password holds the value of the "password" field.
Password string `json:"-"`
// IsAdmin holds the value of the "is_admin" field.
@ -38,7 +40,7 @@ func (*User) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case user.FieldIsAdmin, user.FieldIsActive:
values[i] = new(sql.NullBool)
case user.FieldUsername, user.FieldPassword:
case user.FieldUsername, user.FieldEmail, user.FieldPassword:
values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt:
values[i] = new(sql.NullTime)
@ -71,6 +73,12 @@ func (u *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
u.Username = value.String
}
case user.FieldEmail:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field email", values[i])
} else if value.Valid {
u.Email = value.String
}
case user.FieldPassword:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field password", values[i])
@ -132,6 +140,9 @@ func (u *User) String() string {
builder.WriteString("username=")
builder.WriteString(u.Username)
builder.WriteString(", ")
builder.WriteString("email=")
builder.WriteString(u.Email)
builder.WriteString(", ")
builder.WriteString("password=<sensitive>")
builder.WriteString(", ")
builder.WriteString("is_admin=")

@ -15,6 +15,8 @@ const (
FieldID = "id"
// FieldUsername holds the string denoting the username field in the database.
FieldUsername = "username"
// FieldEmail holds the string denoting the email field in the database.
FieldEmail = "email"
// FieldPassword holds the string denoting the password field in the database.
FieldPassword = "password"
// FieldIsAdmin holds the string denoting the is_admin field in the database.
@ -33,6 +35,7 @@ const (
var Columns = []string{
FieldID,
FieldUsername,
FieldEmail,
FieldPassword,
FieldIsAdmin,
FieldIsActive,
@ -53,6 +56,8 @@ func ValidColumn(column string) bool {
var (
// UsernameValidator is a validator for the "username" field. It is called by the builders before save.
UsernameValidator func(string) error
// EmailValidator is a validator for the "email" field. It is called by the builders before save.
EmailValidator func(string) error
// PasswordValidator is a validator for the "password" field. It is called by the builders before save.
PasswordValidator func(string) error
// DefaultIsAdmin holds the default value on creation for the "is_admin" field.

@ -60,6 +60,11 @@ func Username(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldUsername, v))
}
// Email applies equality check predicate on the "email" field. It's identical to EmailEQ.
func Email(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldEmail, v))
}
// Password applies equality check predicate on the "password" field. It's identical to PasswordEQ.
func Password(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldPassword, v))
@ -150,6 +155,71 @@ func UsernameContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldUsername, v))
}
// EmailEQ applies the EQ predicate on the "email" field.
func EmailEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldEmail, v))
}
// EmailNEQ applies the NEQ predicate on the "email" field.
func EmailNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldEmail, v))
}
// EmailIn applies the In predicate on the "email" field.
func EmailIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldEmail, vs...))
}
// EmailNotIn applies the NotIn predicate on the "email" field.
func EmailNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldEmail, vs...))
}
// EmailGT applies the GT predicate on the "email" field.
func EmailGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldEmail, v))
}
// EmailGTE applies the GTE predicate on the "email" field.
func EmailGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldEmail, v))
}
// EmailLT applies the LT predicate on the "email" field.
func EmailLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldEmail, v))
}
// EmailLTE applies the LTE predicate on the "email" field.
func EmailLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldEmail, v))
}
// EmailContains applies the Contains predicate on the "email" field.
func EmailContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldEmail, v))
}
// EmailHasPrefix applies the HasPrefix predicate on the "email" field.
func EmailHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldEmail, v))
}
// EmailHasSuffix applies the HasSuffix predicate on the "email" field.
func EmailHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldEmail, v))
}
// EmailEqualFold applies the EqualFold predicate on the "email" field.
func EmailEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldEmail, v))
}
// EmailContainsFold applies the ContainsFold predicate on the "email" field.
func EmailContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldEmail, v))
}
// PasswordEQ applies the EQ predicate on the "password" field.
func PasswordEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldPassword, v))

@ -27,6 +27,12 @@ func (uc *UserCreate) SetUsername(s string) *UserCreate {
return uc
}
// SetEmail sets the "email" field.
func (uc *UserCreate) SetEmail(s string) *UserCreate {
uc.mutation.SetEmail(s)
return uc
}
// SetPassword sets the "password" field.
func (uc *UserCreate) SetPassword(s string) *UserCreate {
uc.mutation.SetPassword(s)
@ -170,6 +176,14 @@ func (uc *UserCreate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if _, ok := uc.mutation.Email(); !ok {
return &ValidationError{Name: "email", err: errors.New(`ent: missing required field "User.email"`)}
}
if v, ok := uc.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
}
}
if _, ok := uc.mutation.Password(); !ok {
return &ValidationError{Name: "password", err: errors.New(`ent: missing required field "User.password"`)}
}
@ -229,6 +243,10 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldUsername, field.TypeString, value)
_node.Username = value
}
if value, ok := uc.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value)
_node.Email = value
}
if value, ok := uc.mutation.Password(); ok {
_spec.SetField(user.FieldPassword, field.TypeString, value)
_node.Password = value

@ -34,6 +34,12 @@ func (uu *UserUpdate) SetUsername(s string) *UserUpdate {
return uu
}
// SetEmail sets the "email" field.
func (uu *UserUpdate) SetEmail(s string) *UserUpdate {
uu.mutation.SetEmail(s)
return uu
}
// SetPassword sets the "password" field.
func (uu *UserUpdate) SetPassword(s string) *UserUpdate {
uu.mutation.SetPassword(s)
@ -122,6 +128,11 @@ func (uu *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uu.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
}
}
if v, ok := uu.mutation.Password(); ok {
if err := user.PasswordValidator(v); err != nil {
return &ValidationError{Name: "password", err: fmt.Errorf(`ent: validator failed for field "User.password": %w`, err)}
@ -145,6 +156,9 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) {
if value, ok := uu.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := uu.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value)
}
if value, ok := uu.mutation.Password(); ok {
_spec.SetField(user.FieldPassword, field.TypeString, value)
}
@ -183,6 +197,12 @@ func (uuo *UserUpdateOne) SetUsername(s string) *UserUpdateOne {
return uuo
}
// SetEmail sets the "email" field.
func (uuo *UserUpdateOne) SetEmail(s string) *UserUpdateOne {
uuo.mutation.SetEmail(s)
return uuo
}
// SetPassword sets the "password" field.
func (uuo *UserUpdateOne) SetPassword(s string) *UserUpdateOne {
uuo.mutation.SetPassword(s)
@ -284,6 +304,11 @@ func (uuo *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
if v, ok := uuo.mutation.Email(); ok {
if err := user.EmailValidator(v); err != nil {
return &ValidationError{Name: "email", err: fmt.Errorf(`ent: validator failed for field "User.email": %w`, err)}
}
}
if v, ok := uuo.mutation.Password(); ok {
if err := user.PasswordValidator(v); err != nil {
return &ValidationError{Name: "password", err: fmt.Errorf(`ent: validator failed for field "User.password": %w`, err)}
@ -324,6 +349,9 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error)
if value, ok := uuo.mutation.Username(); ok {
_spec.SetField(user.FieldUsername, field.TypeString, value)
}
if value, ok := uuo.mutation.Email(); ok {
_spec.SetField(user.FieldEmail, field.TypeString, value)
}
if value, ok := uuo.mutation.Password(); ok {
_spec.SetField(user.FieldPassword, field.TypeString, value)
}

8
go.mod

@ -5,6 +5,8 @@ go 1.20
require (
entgo.io/ent v0.11.10
github.com/google/uuid v1.3.0
github.com/gorilla/sessions v1.2.1
github.com/labstack/echo-contrib v0.14.1
github.com/labstack/echo/v4 v4.10.2
github.com/microcosm-cc/bluemonday v1.0.23
github.com/philandstuff/dhall-golang/v6 v6.0.2
@ -22,17 +24,18 @@ require (
github.com/fxamacker/cbor/v2 v2.2.1-0.20200511212021-28e39be4a84f // indirect
github.com/go-openapi/inflect v0.19.0 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/css v1.0.0 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/hashicorp/hcl/v2 v2.13.0 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/labstack/gommon v0.4.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rogpeppe/go-internal v1.10.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
github.com/x448/float16 v0.8.4 // indirect
@ -43,7 +46,6 @@ require (
golang.org/x/text v0.8.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.6.1-0.20230222164832-25d2519c8696 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
lukechampine.com/uint128 v1.2.0 // indirect
modernc.org/cc/v3 v3.40.0 // indirect

20
go.sum

@ -27,30 +27,35 @@ github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzq
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.4/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc=
github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs=
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4=
github.com/labstack/echo-contrib v0.14.1 h1:oNUSCeXQOlCGt3eWafzu0mkXjIh3SINnYgE/UR2kYXQ=
github.com/labstack/echo-contrib v0.14.1/go.mod h1:6jgpHPjGRk0qrysPCfv3SCau6kewjQtYzOk1fLZGMeQ=
github.com/labstack/echo/v4 v4.10.2 h1:n1jAhnq/elIFTHr1EYpiYtyKgx4RW9ccVgkqByZaN2M=
github.com/labstack/echo/v4 v4.10.2/go.mod h1:OEyqf2//K1DFdE57vw2DRgWY0M7s65IVQO2FzvI4J5k=
github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8=
@ -78,12 +83,13 @@ github.com/philandstuff/dhall-golang/v6 v6.0.2 h1:jv8fi4ZYiFe6uGrprx6dY7L3xPcgmE
github.com/philandstuff/dhall-golang/v6 v6.0.2/go.mod h1:XRoxjsqZM2y7KPFhjV7CSVdWpV5CwuTzGjAY/v+1SUU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
@ -139,12 +145,8 @@ golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.6.1-0.20230222164832-25d2519c8696 h1:8985/C5IvACpd9DDXckSnjSBLKDgbxXiyODgi94zOPM=
golang.org/x/tools v0.6.1-0.20230222164832-25d2519c8696/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=

@ -7,10 +7,13 @@ import (
"io/fs"
"net/http"
"path/filepath"
"strconv"
"strings"
"git.dotya.ml/mirre-mt/pcmt/ent"
moduser "git.dotya.ml/mirre-mt/pcmt/modules/user"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/microcosm-cc/bluemonday"
)
@ -139,26 +142,16 @@ func Index() echo.HandlerFunc {
func Signin() echo.HandlerFunc {
return func(c echo.Context) error {
session, err := c.Cookie("session")
if err != nil {
if err == http.ErrNoCookie {
log.Info("no session cookie found")
}
}
sess, _ := session.Get(conf.SessionCookieName, c)
if session != nil {
log.Info("got session")
// if err := session.Valid(); err == nil && session.Expires.After(time.Now()) {
if err := session.Valid(); err == nil {
return c.Redirect(302, "/home")
}
log.Warn("invalid (or expired) session", "error", err.Error())
username := sess.Values["username"]
if username != nil {
return c.Redirect(http.StatusFound, "/home")
}
tpl := getTmpl("signin.tmpl")
err = tpl.Execute(c.Response().Writer,
err := tpl.Execute(c.Response().Writer,
page{
AppName: conf.AppName,
AppVer: appver,
@ -192,7 +185,7 @@ func SigninPost(client *ent.Client) echo.HandlerFunc {
} else {
log.Info("username was not set, returning to /signin")
return c.Redirect(302, "/signin")
return c.Redirect(http.StatusFound, "/signin")
}
if passwd := c.Request().FormValue("password"); passwd != "" {
@ -200,7 +193,7 @@ func SigninPost(client *ent.Client) echo.HandlerFunc {
} else {
log.Info("password was not set, returning to /signin")
return c.Redirect(302, "/signin")
return c.Redirect(http.StatusFound, "/signin")
}
ctx := context.WithValue(context.Background(), moduser.CtxKey{}, log)
@ -210,52 +203,68 @@ func SigninPost(client *ent.Client) echo.HandlerFunc {
if usr.Password != password {
log.Warn("wrong user credentials, redirecting to /signin")
return c.Redirect(302, "/signin")
return c.Redirect(http.StatusFound, "/signin")
}
} else {
if ent.IsNotFound(err) {
c.Logger().Error("user not found")
return c.Redirect(http.StatusFound, "/signin")
}
// just log the error instead of returning it to the user and
// redirect back to /signin.
log.Warn(
c.Logger().Error(
http.StatusText(http.StatusUnauthorized)+" "+err.Error(),
echo.NewHTTPError(
http.StatusUnauthorized,
http.StatusText(http.StatusUnauthorized)+" "+err.Error(),
))
strconv.Itoa(http.StatusUnauthorized)+" "+http.StatusText(http.StatusUnauthorized)+" "+err.Error(),
)
return c.Redirect(302, "/signin")
return c.Redirect(http.StatusFound, "/signin")
}
secure := c.Request().URL.Scheme == "https" //nolint:goconst
cookieSession := &http.Cookie{
Name: "session",
Value: username,
SameSite: http.SameSiteStrictMode,
MaxAge: 3600,
Secure: secure,
HttpOnly: true,
}
c.SetCookie(cookieSession)
sess, _ := session.Get(conf.SessionCookieName, c)
if sess != nil {
sess.Options = &sessions.Options{
Path: "/",
MaxAge: 3600,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteStrictMode,
}
sess.Values["foo"] = "bar"
sess.Values["username"] = username
return c.Redirect(301, "/home")
err := sess.Save(c.Request(), c.Response())
if err != nil {
c.Logger().Error("failed to save session")
err = renderErrorPage(
c.Response().Writer,
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError)+" (make sure you've got cookies enabled)",
err.Error(),
)
if err != nil {
return err
}
}
}
return c.Redirect(http.StatusMovedPermanently, "/home")
}
}
func Signup() echo.HandlerFunc {
return func(c echo.Context) error {
session, err := c.Cookie("session")
if err != nil && err == http.ErrNoCookie {
log.Info("no session cookie found")
}
sess, _ := session.Get(conf.SessionCookieName, c)
if sess != nil {
log.Info("gorilla session", "endpoint", "signup")
if session != nil {
log.Info("got session")
// if err := session.Valid(); err == nil && session.Expires.After(time.Now()) {
if err := session.Valid(); err == nil {
return c.Redirect(302, "/home")
username := sess.Values["username"]
if username != nil {
return c.Redirect(http.StatusFound, "/home")
}
log.Warn("invalid (or expired) session", "error", err.Error())
}
tpl := getTmpl("signup.tmpl")
@ -274,7 +283,7 @@ func Signup() echo.HandlerFunc {
// }
// c.SetCookie(cookieCSRF)
err = tpl.Execute(c.Response().Writer,
err := tpl.Execute(c.Response().Writer,
page{
AppName: conf.AppName,
AppVer: appver,
@ -312,44 +321,111 @@ func SignupPost(client *ent.Client) echo.HandlerFunc {
var username string
if uname := c.Request().FormValue("username"); uname != "" {
username = uname
var email string
if passwd := c.Request().FormValue("password"); passwd != "" {
ctx := context.WithValue(context.Background(), moduser.CtxKey{}, log)
u, err := moduser.CreateUser(ctx, client, username, passwd)
if err != nil {
// TODO: don't return the error to the user, perhaps based
// on the devel mode.
return echo.NewHTTPError(
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError)+" failed to create schema resources "+err.Error(),
)
}
log.Infof("successfully registered user '%s': %#v", username, u)
} else {
log.Info("user registration: password was not set, returning to /signup")
}
} else {
log.Info("user registration: username was not set, returning to /signup")
return c.Redirect(302, "/signup")
uname := c.Request().FormValue("username")
if uname == "" {
c.Logger().Error("signup: username was not set, returning to /signup")
return c.Redirect(http.StatusSeeOther, "/signup")
}
username = uname
mail := c.Request().FormValue("email")
if mail == "" {
c.Logger().Error("signup: email not set")
return c.Redirect(http.StatusSeeOther, "/signup")
}
email = mail
passwd := c.Request().FormValue("password")
if passwd == "" {
log.Info("signup: password was not set, returning to /signup")
return c.Redirect(http.StatusSeeOther, "/signup")
}
ctx := context.WithValue(context.Background(), moduser.CtxKey{}, log)
exists, err := moduser.Exists(ctx, client, username, email)
if err != nil {
c.Logger().Error("error checking whether user exists", err)
return renderErrorPage(
c.Response().Writer,
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError),
err.Error(),
)
}
if exists {
c.Logger().Error("username/email already taken")
return c.Redirect(http.StatusSeeOther, "/signup")
}
u, err := moduser.CreateUser(
ctx,
client,
email,
username,
passwd,
)
if err != nil {
if errors.Is(err, errors.New("username is not unique")) {
c.Logger().Error("username already taken")
// TODO: re-render signup page with a flash message
// stating what went wrong.
return c.Redirect(http.StatusSeeOther, "/signup")
}
// TODO: don't return the error to the user, perhaps based
// on the devel mode.
err = renderErrorPage(
c.Response().Writer,
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError)+" - failed to create schema resources",
err.Error(),
)
if err != nil {
c.Logger().Error("error: %q", err)
return err
}
}
log.Infof("successfully registered user '%s'", username)
log.Debug("user details", "id", u.ID, "email", u.Email, "isAdmin", u.IsAdmin)
secure := c.Request().URL.Scheme == "https" //nolint:goconst
cookieSession := &http.Cookie{
Name: "session",
Value: username,
SameSite: http.SameSiteStrictMode,
MaxAge: 3600,
Secure: secure,
HttpOnly: true,
}
c.SetCookie(cookieSession)
return c.Redirect(301, "/home")
sess, _ := session.Get(conf.SessionCookieName, c)
sess.Options = &sessions.Options{
Path: "/",
MaxAge: 3600,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteStrictMode,
}
sess.Values["foo"] = "bar"
sess.Values["username"] = username
err = sess.Save(c.Request(), c.Response())
if err != nil {
c.Logger().Error("failed to save session")
err = renderErrorPage(
c.Response().Writer,
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError)+" (make sure you've got cookies enabled)",
err.Error(),
)
if err != nil {
return err
}
}
return c.Redirect(http.StatusMovedPermanently, "/home")
}
}
@ -359,65 +435,33 @@ func Home() echo.HandlerFunc {
tpl := getTmpl("home.tmpl")
session, err := c.Cookie("session")
if err != nil {
if err == http.ErrNoCookie {
log.Infof("error no cookie: %q", err)
err = renderErrorPage(
c.Response().Writer,
http.StatusNotFound, http.StatusText(http.StatusNotFound),
err.Error(),
)
if err != nil {
c.Logger().Errorf("error: %q", err)
return err
}
return nil
}
c.Logger().Errorf("error: %q", err)
return echo.NewHTTPError(http.StatusBadRequest, http.StatusText(http.StatusBadRequest))
sess, _ := session.Get(conf.SessionCookieName, c)
if sess == nil {
log.Info("no session, redirecting to /signin", "endpoint", "/home")
return c.Redirect(http.StatusPermanentRedirect, "/signin")
}
if session != nil {
log.Info("got session")
if err := session.Valid(); err != nil {
log.Warn("invalid or expired session?")
c.Logger().Errorf("error: %q", err)
err = renderErrorPage(
c.Response().Writer,
http.StatusNotFound, http.StatusText(http.StatusNotFound),
err.Error(),
)
if err != nil {
c.Logger().Errorf("error: %q", err)
return echo.NewHTTPError(
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError),
)
}
}
username = session.Value
if sess.Values["foo"] != nil {
log.Info("gorilla session", "custom field test", sess.Values["foo"].(string))
}
uname := sess.Values["username"]
if uname == nil {
log.Info("session cookie found but username invalid, redirecting to signin", "endpoint", "/home")
return c.Redirect(http.StatusSeeOther, "/signin")
}
log.Info("gorilla session", "username", sess.Values["username"].(string))
username = sess.Values["username"].(string)
// example denial.
// if _, err := c.Cookie("aha"); err != nil {
// log.Printf("error: %q", err)
// return echo.NewHTTPError(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized))
// }
err = tpl.Execute(c.Response().Writer,
err := tpl.Execute(c.Response().Writer,
page{
AppName: conf.AppName,
AppVer: appver,
@ -454,66 +498,52 @@ func Home() echo.HandlerFunc {
func Logout() echo.HandlerFunc {
return func(c echo.Context) error {
if c.Request().Method == "POST" {
session, err := c.Cookie("session")
if err != nil {
if errors.Is(err, http.ErrNoCookie) {
log.Info("nobody to log out, redirecting to /signin")
switch {
case c.Request().Method == "POST":
sess, _ := session.Get(conf.SessionCookieName, c)
if sess != nil {
log.Infof("max-age before logout: %d", sess.Options.MaxAge)
sess.Options.MaxAge = -1
return c.Redirect(302, "/signin")
if username := sess.Values["username"]; username != nil {
sess.Values["username"] = ""
}
c.Logger().Errorf("error: %q", err)
return err
err := sess.Save(c.Request(), c.Response())
if err != nil {
c.Logger().Error("could not delete session cookie")
}
}
var username string
if err := session.Valid(); err == nil {
username = session.Value
}
return c.Redirect(http.StatusMovedPermanently, "/logout")
log.Infof("logging out user '%s'", username)
case c.Request().Method == "GET":
tpl := getTmpl("logout.tmpl")
secure := c.Request().URL.Scheme == "https" //nolint:goconst
cookieSession := &http.Cookie{
Name: "session",
Value: "",
SameSite: http.SameSiteStrictMode,
MaxAge: -1,
Secure: secure,
HttpOnly: true,
}
c.SetCookie(cookieSession)
}
tpl := getTmpl("logout.tmpl")
err := tpl.Execute(c.Response().Writer,
page{
AppName: conf.AppName,
AppVer: appver,
Title: "Logout",
DevelMode: conf.DevelMode,
Current: "logout",
},
)
if err != nil {
log.Warnf("error: %q", err)
c.Logger().Errorf("error: %q", err)
err = renderErrorPage(
c.Response().Writer,
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError),
err.Error(),
err := tpl.Execute(c.Response().Writer,
page{
AppName: conf.AppName,
AppVer: appver,
Title: "Logout",
DevelMode: conf.DevelMode,
Current: "logout",
},
)
if err != nil {
c.Logger().Errorf("error: %q", err)
return err
err = renderErrorPage(
c.Response().Writer,
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError),
err.Error(),
)
if err != nil {
c.Logger().Errorf("error: %q", err)
return err
}
}
}

@ -2,6 +2,7 @@ package user
import (
"context"
"errors"
"fmt"
"git.dotya.ml/mirre-mt/pcmt/ent"
@ -10,20 +11,34 @@ import (
)
// CreateUser adds a user entry to the database.
func CreateUser(ctx context.Context, client *ent.Client, username, password string) (*ent.User, error) {
func CreateUser(ctx context.Context, client *ent.Client, email, username, password string) (*ent.User, error) {
log := ctx.Value(CtxKey{}).(*slogging.Logger)
u, err := client.User.
Create().
SetEmail(email).
SetUsername(username).
SetPassword(password).
Save(ctx)
// TODO: saving cleartext password, rework!
if err != nil {
log.Infof("error querying user: %q", err)
if ent.IsConstraintError(err) {
log.Errorf("the username '%s' already exists", username)
return nil, errors.New("username is not unique")
}
return nil, fmt.Errorf("failed creating user: %w", err)
}
log.Infof("user was created: %#v", u)
log.Debug(
fmt.Sprintf(
"user successfully created - id: %s, name: %s, isActive: %t, isAdmin: %t, createdAt: %s, updatedAt: %s",
u.ID, u.Username, u.IsActive, u.IsAdmin, u.CreatedAt, u.UpdatedAt,
),
)
return u, nil
}
@ -45,3 +60,105 @@ func QueryUser(ctx context.Context, client *ent.Client, username string) (*ent.U
return u, nil
}
// Exists determines whether the username OR email in question was previously
// used to register a user.
func Exists(ctx context.Context, client *ent.Client, username, email string) (bool, error) {
log := ctx.Value(CtxKey{}).(*slogging.Logger)
usernameExists, err := UsernameExists(ctx, client, username)
if err != nil {
log.Warn("failed to check whether username is taken", "error", err, "username requested", username)
return false, fmt.Errorf("failed querying user: %w", err)
}
emailExists, err := EmailExists(ctx, client, email)
if err != nil {
log.Warn("failed to check whether user email exists", "error", err, "user email requested", email)
return false, fmt.Errorf("failed querying user: %w", err)
}
switch {
case usernameExists && emailExists:
log.Infof("user exists: both username '%s' and email: '%s' matched", username, email)
return true, nil
case usernameExists:
log.Infof("username '%s' is already taken", username)
return true, nil
case emailExists:
log.Infof("email '%s' is already registered", email)
return true, nil
}
return false, nil
}
// UsernameExists queries the database to check whether the username is
// available or taken, returns a bool and an error, which will be nil unless
// the error is one of IsNot{Found,Singular}.
func UsernameExists(ctx context.Context, client *ent.Client, username string) (bool, error) {
log := ctx.Value(CtxKey{}).(*slogging.Logger)
usr, err := client.User.
Query().
Where(user.Username(username)).
Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return false, nil
} else if ent.IsNotSingular(err) {
log.Errorf("apparently more than one user managed to acquire the username '%s', bailing", username)
return true, err
}
log.Warn("failed to check whether user exists", "error", err, "username queried", username)
return false, fmt.Errorf("failed querying username: %w", err)
}
if usr != nil {
log.Infof("username '%s' found, user id: %s", username, usr.ID)
return true, nil
}
log.Infof("username '%s' not found", username)
return false, nil
}
// EmailExists queries the database to check whether the email was already
// used; returns a bool and an error, which will be nil unless the error is not
// one of IsNot{Found,Singular}.
func EmailExists(ctx context.Context, client *ent.Client, email string) (bool, error) {
log := ctx.Value(CtxKey{}).(*slogging.Logger)
usr, err := client.User.
Query().
Where(user.Email(email)).
Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return false, nil
} else if ent.IsNotSingular(err) {
log.Errorf("apparently more than one user managed to register using the email '%s', bailing", email)
return true, err
}
log.Warn("failed to check whether email exists", "error", err, "email queried", email)
return false, fmt.Errorf("failed querying user email: %w", err)
}
if usr != nil {
log.Infof("user email '%s' found, user id: %s", email, usr.ID)
return true, nil
}
log.Infof("user email '%s' not found", email)
return false, nil
}

108
modules/user/user_test.go Normal file

@ -0,0 +1,108 @@
package user
import (
"context"
"log"
"testing"
"git.dotya.ml/mirre-mt/pcmt/ent"
"git.dotya.ml/mirre-mt/pcmt/slogging"
_ "github.com/xiaoqidun/entps"
)
func TestUserExists(t *testing.T) {
t.Parallel()
db := supplyDBClient()
if db == nil {
t.Error("could not connect to db")
}
defer db.Close()
username := "dude"
email := "dude@b.cc"
ctx := getCtx()
usernameFound, err := UsernameExists(ctx, db, username)
if err != nil {
t.Errorf("error checking for username {%s} existence: %q",
username,
err,
)
}
if usernameFound {
t.Errorf("unexpected: user{%s} should not have been found",
username,
)
}
if _, err := EmailExists(ctx, db, email); err != nil {
t.Errorf("unexpected: user email '%s' should not have been found",
email,
)
}
usr, err := CreateUser(ctx, db, email, username, "so strong")
if err != nil {
t.Fatalf("failed to create user, error: %q", err)
} else if usr == nil {
t.Fatal("got nil usr back")
}
if usr.Username != username {
t.Errorf("got back wrong username, want: %s, got: %s", username, usr.Username)
}
usernameFound, err = UsernameExists(ctx, db, username)
if err != nil {
t.Errorf("error checking for username {%s} existence: %q",
username,
err,
)
}
if !usernameFound {
t.Errorf("unexpected: user{%s} should not have been found",
username,
)
}
exists, err := Exists(ctx, db, username, email)
if err != nil {
t.Errorf("error checking whether user exists: %q", err)
}
if !exists {
t.Errorf("unexpected: user{%s} does not exists and they should", username)
}
}
func supplyDBClient() *ent.Client {
connstr := "file:ent_tests?mode=memory&cache=shared&_fk=1"
db, err := ent.Open("sqlite3", connstr)
if err != nil {
log.Printf("failed to open a connection to sqlite %q\n", err)
return nil
}
if err = db.Schema.Create(context.Background()); err != nil {
log.Printf("failed creating schema resources: %v", err)
return nil
}
return db
}
func getCtx() context.Context {
l := slogging.Init(false)
ctx := context.WithValue(
context.Background(),
CtxKey{},
l,
)
return ctx
}