Compare commits

...

1 Commits

Author SHA1 Message Date
Matthew Kilgore
dd873a95da Add default group handling and user-groups relationship
- Introduced `default_group_id` field in the User model to manage user group defaults.
- Updated user creation and update logic to utilize the new default group ID.
- Implemented a many-to-many relationship between users and groups via a new `user_groups` junction table.
- Refactored relevant queries and middleware to support tenant-based access using the default group.
2025-12-26 20:16:52 -05:00
29 changed files with 835 additions and 243 deletions

View File

@@ -55,7 +55,7 @@ func (a *app) SetupDemo() error {
return errors.New("failed to setup demo") return errors.New("failed to setup demo")
} }
_, err = a.services.Items.CsvImport(ctx, self.GroupID, strings.NewReader(csvText)) _, err = a.services.Items.CsvImport(ctx, self.DefaultGroupID, strings.NewReader(csvText))
if err != nil { if err != nil {
log.Err(err).Msg("Failed to import CSV") log.Err(err).Msg("Failed to import CSV")
return errors.New("failed to setup demo") return errors.New("failed to setup demo")

View File

@@ -339,7 +339,7 @@ func (ctrl *V1Controller) HandleItemsImport() errchain.HandlerFunc {
user := services.UseUserCtx(r.Context()) user := services.UseUserCtx(r.Context())
_, err = ctrl.svc.Items.CsvImport(r.Context(), user.GroupID, file) _, err = ctrl.svc.Items.CsvImport(r.Context(), user.DefaultGroupID, file)
if err != nil { if err != nil {
log.Err(err).Msg("failed to import items") log.Err(err).Msg("failed to import items")
return validate.NewRequestError(err, http.StatusInternalServerError) return validate.NewRequestError(err, http.StatusInternalServerError)

View File

@@ -1,9 +1,10 @@
package v1 package v1
import ( import (
"net/http"
"github.com/hay-kot/httpkit/errchain" "github.com/hay-kot/httpkit/errchain"
"github.com/sysadminsmedia/homebox/backend/internal/core/services" "github.com/sysadminsmedia/homebox/backend/internal/core/services"
"net/http"
) )
// HandleBillOfMaterialsExport godoc // HandleBillOfMaterialsExport godoc
@@ -18,7 +19,7 @@ func (ctrl *V1Controller) HandleBillOfMaterialsExport() errchain.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error {
actor := services.UseUserCtx(r.Context()) actor := services.UseUserCtx(r.Context())
csv, err := ctrl.svc.Items.ExportBillOfMaterialsCSV(r.Context(), actor.GroupID) csv, err := ctrl.svc.Items.ExportBillOfMaterialsCSV(r.Context(), actor.DefaultGroupID)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/google/uuid"
"github.com/hay-kot/httpkit/errchain" "github.com/hay-kot/httpkit/errchain"
v1 "github.com/sysadminsmedia/homebox/backend/app/api/handlers/v1" v1 "github.com/sysadminsmedia/homebox/backend/app/api/handlers/v1"
"github.com/sysadminsmedia/homebox/backend/internal/core/services" "github.com/sysadminsmedia/homebox/backend/internal/core/services"
@@ -152,3 +153,48 @@ func (a *app) mwAuthToken(next errchain.Handler) errchain.Handler {
return next.ServeHTTP(w, r) return next.ServeHTTP(w, r)
}) })
} }
// mwTenant is a middleware that will parse the X-Tenant header and validate the user has access
// to the requested tenant. If no header is provided, the user's default group is used.
//
// WARNING: This middleware _MUST_ be called after mwAuthToken
func (a *app) mwTenant(next errchain.Handler) errchain.Handler {
return errchain.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
// Get the user from context (set by mwAuthToken)
user := services.UseUserCtx(ctx)
if user == nil {
return validate.NewRequestError(errors.New("user context not found"), http.StatusInternalServerError)
}
tenantID := user.DefaultGroupID
// Check for X-Tenant header
if tenantHeader := r.Header.Get("X-Tenant"); tenantHeader != "" {
parsedTenantID, err := uuid.Parse(tenantHeader)
if err != nil {
return validate.NewRequestError(errors.New("invalid X-Tenant header format"), http.StatusBadRequest)
}
// Validate user has access to the requested tenant
hasAccess := false
for _, gid := range user.GroupIDs {
if gid == parsedTenantID {
hasAccess = true
break
}
}
if !hasAccess {
return validate.NewRequestError(errors.New("user does not have access to the requested tenant"), http.StatusForbidden)
}
tenantID = parsedTenantID
}
// Set the tenant in context
r = r.WithContext(services.SetTenantCtx(ctx, tenantID))
return next.ServeHTTP(w, r)
})
}

View File

@@ -82,6 +82,7 @@ func (a *app) mountRoutes(r *chi.Mux, chain *errchain.ErrChain, repos *repo.AllR
userMW := []errchain.Middleware{ userMW := []errchain.Middleware{
a.mwAuthToken, a.mwAuthToken,
a.mwTenant,
a.mwRoles(RoleModeOr, authroles.RoleUser.String()), a.mwRoles(RoleModeOr, authroles.RoleUser.String()),
} }

View File

@@ -350,6 +350,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
@@ -374,6 +376,8 @@ github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOF
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olahol/melody v1.4.0 h1:Pa5SdeZL/zXPi1tJuMAPDbl4n3gQOThSL6G1p4qZ4SI= github.com/olahol/melody v1.4.0 h1:Pa5SdeZL/zXPi1tJuMAPDbl4n3gQOThSL6G1p4qZ4SI=
github.com/olahol/melody v1.4.0/go.mod h1:GgkTl6Y7yWj/HtfD48Q5vLKPVoZOH+Qqgfa7CvJgJM4= github.com/olahol/melody v1.4.0/go.mod h1:GgkTl6Y7yWj/HtfD48Q5vLKPVoZOH+Qqgfa7CvJgJM4=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/onsi/ginkgo/v2 v2.9.2 h1:BA2GMJOtfGAfagzYtrAlufIP0lq6QERkFmHLMLPwFSU= github.com/onsi/ginkgo/v2 v2.9.2 h1:BA2GMJOtfGAfagzYtrAlufIP0lq6QERkFmHLMLPwFSU=
github.com/onsi/ginkgo/v2 v2.9.2/go.mod h1:WHcJJG2dIlcCqVfBAwUCrJxSPFb6v4azBwgxeMeDuts= github.com/onsi/ginkgo/v2 v2.9.2/go.mod h1:WHcJJG2dIlcCqVfBAwUCrJxSPFb6v4azBwgxeMeDuts=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
@@ -418,6 +422,10 @@ github.com/shirou/gopsutil/v4 v4.25.11 h1:X53gB7muL9Gnwwo2evPSE+SfOrltMoR6V3xJAX
github.com/shirou/gopsutil/v4 v4.25.11/go.mod h1:EivAfP5x2EhLp2ovdpKSozecVXn1TmuG7SMzs/Wh4PU= github.com/shirou/gopsutil/v4 v4.25.11/go.mod h1:EivAfP5x2EhLp2ovdpKSozecVXn1TmuG7SMzs/Wh4PU=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo=
github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View File

@@ -14,6 +14,7 @@ type contextKeys struct {
var ( var (
ContextUser = &contextKeys{name: "User"} ContextUser = &contextKeys{name: "User"}
ContextUserToken = &contextKeys{name: "UserToken"} ContextUserToken = &contextKeys{name: "UserToken"}
ContextTenant = &contextKeys{name: "Tenant"}
) )
type Context struct { type Context struct {
@@ -33,10 +34,14 @@ type Context struct {
// This extracts the users from the context and embeds it into the ServiceContext struct // This extracts the users from the context and embeds it into the ServiceContext struct
func NewContext(ctx context.Context) Context { func NewContext(ctx context.Context) Context {
user := UseUserCtx(ctx) user := UseUserCtx(ctx)
gid := UseTenantCtx(ctx)
if gid == uuid.Nil && user != nil {
gid = user.DefaultGroupID
}
return Context{ return Context{
Context: ctx, Context: ctx,
UID: user.ID, UID: user.ID,
GID: user.GroupID, GID: gid,
User: user, User: user,
} }
} }
@@ -64,3 +69,17 @@ func UseTokenCtx(ctx context.Context) string {
} }
return "" return ""
} }
// UseTenantCtx is a helper function that returns the tenant group ID from the context.
// Returns uuid.Nil if not set.
func UseTenantCtx(ctx context.Context) uuid.UUID {
if val := ctx.Value(ContextTenant); val != nil {
return val.(uuid.UUID)
}
return uuid.Nil
}
// SetTenantCtx is a helper function that sets the ContextTenant in the context.
func SetTenantCtx(ctx context.Context, tenantID uuid.UUID) context.Context {
return context.WithValue(ctx, ContextTenant, tenantID)
}

View File

@@ -14,7 +14,7 @@ type GroupService struct {
func (svc *GroupService) UpdateGroup(ctx Context, data repo.GroupUpdate) (repo.Group, error) { func (svc *GroupService) UpdateGroup(ctx Context, data repo.GroupUpdate) (repo.Group, error) {
if data.Name == "" { if data.Name == "" {
data.Name = ctx.User.GroupName return repo.Group{}, errors.New("group name cannot be empty")
} }
if data.Currency == "" { if data.Currency == "" {

View File

@@ -81,12 +81,12 @@ func (svc *UserService) RegisterUser(ctx context.Context, data UserRegistration)
hashed, _ := hasher.HashPassword(data.Password) hashed, _ := hasher.HashPassword(data.Password)
usrCreate := repo.UserCreate{ usrCreate := repo.UserCreate{
Name: data.Name, Name: data.Name,
Email: data.Email, Email: data.Email,
Password: &hashed, Password: &hashed,
IsSuperuser: false, IsSuperuser: false,
GroupID: group.ID, DefaultGroupID: group.ID,
IsOwner: creatingGroup, IsOwner: creatingGroup,
} }
usr, err := svc.repos.Users.Create(ctx, usrCreate) usr, err := svc.repos.Users.Create(ctx, usrCreate)
@@ -99,7 +99,7 @@ func (svc *UserService) RegisterUser(ctx context.Context, data UserRegistration)
if creatingGroup { if creatingGroup {
log.Debug().Msg("creating default labels") log.Debug().Msg("creating default labels")
for _, label := range defaultLabels() { for _, label := range defaultLabels() {
_, err := svc.repos.Labels.Create(ctx, usr.GroupID, label) _, err := svc.repos.Labels.Create(ctx, usr.DefaultGroupID, label)
if err != nil { if err != nil {
return repo.UserOut{}, err return repo.UserOut{}, err
} }
@@ -107,7 +107,7 @@ func (svc *UserService) RegisterUser(ctx context.Context, data UserRegistration)
log.Debug().Msg("creating default locations") log.Debug().Msg("creating default locations")
for _, location := range defaultLocations() { for _, location := range defaultLocations() {
_, err := svc.repos.Locations.Create(ctx, usr.GroupID, location) _, err := svc.repos.Locations.Create(ctx, usr.DefaultGroupID, location)
if err != nil { if err != nil {
return repo.UserOut{}, err return repo.UserOut{}, err
} }
@@ -287,12 +287,12 @@ func (svc *UserService) registerOIDCUser(ctx context.Context, issuer, subject, e
} }
usrCreate := repo.UserCreate{ usrCreate := repo.UserCreate{
Name: name, Name: name,
Email: email, Email: email,
Password: nil, Password: nil,
IsSuperuser: false, IsSuperuser: false,
GroupID: group.ID, DefaultGroupID: group.ID,
IsOwner: true, IsOwner: true,
} }
entUser, err := svc.repos.Users.CreateWithOIDC(ctx, usrCreate, issuer, subject) entUser, err := svc.repos.Users.CreateWithOIDC(ctx, usrCreate, issuer, subject)

View File

@@ -2711,15 +2711,15 @@ func (c *UserClient) GetX(ctx context.Context, id uuid.UUID) *User {
return obj return obj
} }
// QueryGroup queries the group edge of a User. // QueryGroups queries the groups edge of a User.
func (c *UserClient) QueryGroup(_m *User) *GroupQuery { func (c *UserClient) QueryGroups(_m *User) *GroupQuery {
query := (&GroupClient{config: c.config}).Query() query := (&GroupClient{config: c.config}).Query()
query.path = func(context.Context) (fromV *sql.Selector, _ error) { query.path = func(context.Context) (fromV *sql.Selector, _ error) {
id := _m.ID id := _m.ID
step := sqlgraph.NewStep( step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, id), sqlgraph.From(user.Table, user.FieldID, id),
sqlgraph.To(group.Table, group.FieldID), sqlgraph.To(group.Table, group.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, user.GroupTable, user.GroupColumn), sqlgraph.Edge(sqlgraph.O2M, false, user.GroupsTable, user.GroupsColumn),
) )
fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
return fromV, nil return fromV, nil

View File

@@ -29,6 +29,7 @@ type Group struct {
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set. // The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"` Edges GroupEdges `json:"edges"`
user_groups *uuid.UUID
selectValues sql.SelectValues selectValues sql.SelectValues
} }
@@ -127,6 +128,8 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
case group.FieldID: case group.FieldID:
values[i] = new(uuid.UUID) values[i] = new(uuid.UUID)
case group.ForeignKeys[0]: // user_groups
values[i] = &sql.NullScanner{S: new(uuid.UUID)}
default: default:
values[i] = new(sql.UnknownType) values[i] = new(sql.UnknownType)
} }
@@ -172,6 +175,13 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid { } else if value.Valid {
_m.Currency = value.String _m.Currency = value.String
} }
case group.ForeignKeys[0]:
if value, ok := values[i].(*sql.NullScanner); !ok {
return fmt.Errorf("unexpected type %T for field user_groups", values[i])
} else if value.Valid {
_m.user_groups = new(uuid.UUID)
*_m.user_groups = *value.S.(*uuid.UUID)
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }

View File

@@ -99,6 +99,12 @@ var Columns = []string{
FieldCurrency, FieldCurrency,
} }
// ForeignKeys holds the SQL foreign-keys that are owned by the "groups"
// table and are not defined as standalone fields in the schema.
var ForeignKeys = []string{
"user_groups",
}
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool { func ValidColumn(column string) bool {
for i := range Columns { for i := range Columns {
@@ -106,6 +112,11 @@ func ValidColumn(column string) bool {
return true return true
} }
} }
for i := range ForeignKeys {
if column == ForeignKeys[i] {
return true
}
}
return false return false
} }

View File

@@ -38,6 +38,7 @@ type GroupQuery struct {
withInvitationTokens *GroupInvitationTokenQuery withInvitationTokens *GroupInvitationTokenQuery
withNotifiers *NotifierQuery withNotifiers *NotifierQuery
withItemTemplates *ItemTemplateQuery withItemTemplates *ItemTemplateQuery
withFKs bool
// intermediate query (i.e. traversal path). // intermediate query (i.e. traversal path).
sql *sql.Selector sql *sql.Selector
path func(context.Context) (*sql.Selector, error) path func(context.Context) (*sql.Selector, error)
@@ -587,6 +588,7 @@ func (_q *GroupQuery) prepareQuery(ctx context.Context) error {
func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) { func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group, error) {
var ( var (
nodes = []*Group{} nodes = []*Group{}
withFKs = _q.withFKs
_spec = _q.querySpec() _spec = _q.querySpec()
loadedTypes = [7]bool{ loadedTypes = [7]bool{
_q.withUsers != nil, _q.withUsers != nil,
@@ -598,6 +600,9 @@ func (_q *GroupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Group,
_q.withItemTemplates != nil, _q.withItemTemplates != nil,
} }
) )
if withFKs {
_spec.Node.Columns = append(_spec.Node.Columns, group.ForeignKeys...)
}
_spec.ScanValues = func(columns []string) ([]any, error) { _spec.ScanValues = func(columns []string) ([]any, error) {
return (*Group).scanValues(nil, columns) return (*Group).scanValues(nil, columns)
} }

View File

@@ -98,12 +98,21 @@ var (
{Name: "updated_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime},
{Name: "name", Type: field.TypeString, Size: 255}, {Name: "name", Type: field.TypeString, Size: 255},
{Name: "currency", Type: field.TypeString, Default: "usd"}, {Name: "currency", Type: field.TypeString, Default: "usd"},
{Name: "user_groups", Type: field.TypeUUID, Nullable: true},
} }
// GroupsTable holds the schema information for the "groups" table. // GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{ GroupsTable = &schema.Table{
Name: "groups", Name: "groups",
Columns: GroupsColumns, Columns: GroupsColumns,
PrimaryKey: []*schema.Column{GroupsColumns[0]}, PrimaryKey: []*schema.Column{GroupsColumns[0]},
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "groups_users_groups",
Columns: []*schema.Column{GroupsColumns[5]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.SetNull,
},
},
} }
// GroupInvitationTokensColumns holds the columns for the "group_invitation_tokens" table. // GroupInvitationTokensColumns holds the columns for the "group_invitation_tokens" table.
GroupInvitationTokensColumns = []*schema.Column{ GroupInvitationTokensColumns = []*schema.Column{
@@ -468,7 +477,8 @@ var (
{Name: "activated_on", Type: field.TypeTime, Nullable: true}, {Name: "activated_on", Type: field.TypeTime, Nullable: true},
{Name: "oidc_issuer", Type: field.TypeString, Nullable: true}, {Name: "oidc_issuer", Type: field.TypeString, Nullable: true},
{Name: "oidc_subject", Type: field.TypeString, Nullable: true}, {Name: "oidc_subject", Type: field.TypeString, Nullable: true},
{Name: "group_users", Type: field.TypeUUID}, {Name: "default_group_id", Type: field.TypeUUID, Nullable: true},
{Name: "group_users", Type: field.TypeUUID, Nullable: true},
} }
// UsersTable holds the schema information for the "users" table. // UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{ UsersTable = &schema.Table{
@@ -478,9 +488,9 @@ var (
ForeignKeys: []*schema.ForeignKey{ ForeignKeys: []*schema.ForeignKey{
{ {
Symbol: "users_groups_users", Symbol: "users_groups_users",
Columns: []*schema.Column{UsersColumns[12]}, Columns: []*schema.Column{UsersColumns[13]},
RefColumns: []*schema.Column{GroupsColumns[0]}, RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.Cascade, OnDelete: schema.SetNull,
}, },
}, },
Indexes: []*schema.Index{ Indexes: []*schema.Index{
@@ -541,6 +551,7 @@ func init() {
AttachmentsTable.ForeignKeys[1].RefTable = ItemsTable AttachmentsTable.ForeignKeys[1].RefTable = ItemsTable
AuthRolesTable.ForeignKeys[0].RefTable = AuthTokensTable AuthRolesTable.ForeignKeys[0].RefTable = AuthTokensTable
AuthTokensTable.ForeignKeys[0].RefTable = UsersTable AuthTokensTable.ForeignKeys[0].RefTable = UsersTable
GroupsTable.ForeignKeys[0].RefTable = UsersTable
GroupInvitationTokensTable.ForeignKeys[0].RefTable = GroupsTable GroupInvitationTokensTable.ForeignKeys[0].RefTable = GroupsTable
ItemsTable.ForeignKeys[0].RefTable = GroupsTable ItemsTable.ForeignKeys[0].RefTable = GroupsTable
ItemsTable.ForeignKeys[1].RefTable = ItemsTable ItemsTable.ForeignKeys[1].RefTable = ItemsTable

View File

@@ -12583,9 +12583,11 @@ type UserMutation struct {
activated_on *time.Time activated_on *time.Time
oidc_issuer *string oidc_issuer *string
oidc_subject *string oidc_subject *string
default_group_id *uuid.UUID
clearedFields map[string]struct{} clearedFields map[string]struct{}
group *uuid.UUID groups map[uuid.UUID]struct{}
clearedgroup bool removedgroups map[uuid.UUID]struct{}
clearedgroups bool
auth_tokens map[uuid.UUID]struct{} auth_tokens map[uuid.UUID]struct{}
removedauth_tokens map[uuid.UUID]struct{} removedauth_tokens map[uuid.UUID]struct{}
clearedauth_tokens bool clearedauth_tokens bool
@@ -13149,43 +13151,107 @@ func (m *UserMutation) ResetOidcSubject() {
delete(m.clearedFields, user.FieldOidcSubject) delete(m.clearedFields, user.FieldOidcSubject)
} }
// SetGroupID sets the "group" edge to the Group entity by id. // SetDefaultGroupID sets the "default_group_id" field.
func (m *UserMutation) SetGroupID(id uuid.UUID) { func (m *UserMutation) SetDefaultGroupID(u uuid.UUID) {
m.group = &id m.default_group_id = &u
} }
// ClearGroup clears the "group" edge to the Group entity. // DefaultGroupID returns the value of the "default_group_id" field in the mutation.
func (m *UserMutation) ClearGroup() { func (m *UserMutation) DefaultGroupID() (r uuid.UUID, exists bool) {
m.clearedgroup = true v := m.default_group_id
if v == nil {
return
}
return *v, true
} }
// GroupCleared reports if the "group" edge to the Group entity was cleared. // OldDefaultGroupID returns the old "default_group_id" field's value of the User entity.
func (m *UserMutation) GroupCleared() bool { // If the User object wasn't provided to the builder, the object is fetched from the database.
return m.clearedgroup // An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldDefaultGroupID(ctx context.Context) (v *uuid.UUID, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldDefaultGroupID is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldDefaultGroupID requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldDefaultGroupID: %w", err)
}
return oldValue.DefaultGroupID, nil
} }
// GroupID returns the "group" edge ID in the mutation. // ClearDefaultGroupID clears the value of the "default_group_id" field.
func (m *UserMutation) GroupID() (id uuid.UUID, exists bool) { func (m *UserMutation) ClearDefaultGroupID() {
if m.group != nil { m.default_group_id = nil
return *m.group, true m.clearedFields[user.FieldDefaultGroupID] = struct{}{}
}
// DefaultGroupIDCleared returns if the "default_group_id" field was cleared in this mutation.
func (m *UserMutation) DefaultGroupIDCleared() bool {
_, ok := m.clearedFields[user.FieldDefaultGroupID]
return ok
}
// ResetDefaultGroupID resets all changes to the "default_group_id" field.
func (m *UserMutation) ResetDefaultGroupID() {
m.default_group_id = nil
delete(m.clearedFields, user.FieldDefaultGroupID)
}
// AddGroupIDs adds the "groups" edge to the Group entity by ids.
func (m *UserMutation) AddGroupIDs(ids ...uuid.UUID) {
if m.groups == nil {
m.groups = make(map[uuid.UUID]struct{})
}
for i := range ids {
m.groups[ids[i]] = struct{}{}
}
}
// ClearGroups clears the "groups" edge to the Group entity.
func (m *UserMutation) ClearGroups() {
m.clearedgroups = true
}
// GroupsCleared reports if the "groups" edge to the Group entity was cleared.
func (m *UserMutation) GroupsCleared() bool {
return m.clearedgroups
}
// RemoveGroupIDs removes the "groups" edge to the Group entity by IDs.
func (m *UserMutation) RemoveGroupIDs(ids ...uuid.UUID) {
if m.removedgroups == nil {
m.removedgroups = make(map[uuid.UUID]struct{})
}
for i := range ids {
delete(m.groups, ids[i])
m.removedgroups[ids[i]] = struct{}{}
}
}
// RemovedGroups returns the removed IDs of the "groups" edge to the Group entity.
func (m *UserMutation) RemovedGroupsIDs() (ids []uuid.UUID) {
for id := range m.removedgroups {
ids = append(ids, id)
} }
return return
} }
// GroupIDs returns the "group" edge IDs in the mutation. // GroupsIDs returns the "groups" edge IDs in the mutation.
// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use func (m *UserMutation) GroupsIDs() (ids []uuid.UUID) {
// GroupID instead. It exists only for internal usage by the builders. for id := range m.groups {
func (m *UserMutation) GroupIDs() (ids []uuid.UUID) { ids = append(ids, id)
if id := m.group; id != nil {
ids = append(ids, *id)
} }
return return
} }
// ResetGroup resets all changes to the "group" edge. // ResetGroups resets all changes to the "groups" edge.
func (m *UserMutation) ResetGroup() { func (m *UserMutation) ResetGroups() {
m.group = nil m.groups = nil
m.clearedgroup = false m.clearedgroups = false
m.removedgroups = nil
} }
// AddAuthTokenIDs adds the "auth_tokens" edge to the AuthTokens entity by ids. // AddAuthTokenIDs adds the "auth_tokens" edge to the AuthTokens entity by ids.
@@ -13330,7 +13396,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UserMutation) Fields() []string { func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 11) fields := make([]string, 0, 12)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt) fields = append(fields, user.FieldCreatedAt)
} }
@@ -13364,6 +13430,9 @@ func (m *UserMutation) Fields() []string {
if m.oidc_subject != nil { if m.oidc_subject != nil {
fields = append(fields, user.FieldOidcSubject) fields = append(fields, user.FieldOidcSubject)
} }
if m.default_group_id != nil {
fields = append(fields, user.FieldDefaultGroupID)
}
return fields return fields
} }
@@ -13394,6 +13463,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.OidcIssuer() return m.OidcIssuer()
case user.FieldOidcSubject: case user.FieldOidcSubject:
return m.OidcSubject() return m.OidcSubject()
case user.FieldDefaultGroupID:
return m.DefaultGroupID()
} }
return nil, false return nil, false
} }
@@ -13425,6 +13496,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldOidcIssuer(ctx) return m.OldOidcIssuer(ctx)
case user.FieldOidcSubject: case user.FieldOidcSubject:
return m.OldOidcSubject(ctx) return m.OldOidcSubject(ctx)
case user.FieldDefaultGroupID:
return m.OldDefaultGroupID(ctx)
} }
return nil, fmt.Errorf("unknown User field %s", name) return nil, fmt.Errorf("unknown User field %s", name)
} }
@@ -13511,6 +13584,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
} }
m.SetOidcSubject(v) m.SetOidcSubject(v)
return nil return nil
case user.FieldDefaultGroupID:
v, ok := value.(uuid.UUID)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetDefaultGroupID(v)
return nil
} }
return fmt.Errorf("unknown User field %s", name) return fmt.Errorf("unknown User field %s", name)
} }
@@ -13553,6 +13633,9 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldOidcSubject) { if m.FieldCleared(user.FieldOidcSubject) {
fields = append(fields, user.FieldOidcSubject) fields = append(fields, user.FieldOidcSubject)
} }
if m.FieldCleared(user.FieldDefaultGroupID) {
fields = append(fields, user.FieldDefaultGroupID)
}
return fields return fields
} }
@@ -13579,6 +13662,9 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldOidcSubject: case user.FieldOidcSubject:
m.ClearOidcSubject() m.ClearOidcSubject()
return nil return nil
case user.FieldDefaultGroupID:
m.ClearDefaultGroupID()
return nil
} }
return fmt.Errorf("unknown User nullable field %s", name) return fmt.Errorf("unknown User nullable field %s", name)
} }
@@ -13620,6 +13706,9 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldOidcSubject: case user.FieldOidcSubject:
m.ResetOidcSubject() m.ResetOidcSubject()
return nil return nil
case user.FieldDefaultGroupID:
m.ResetDefaultGroupID()
return nil
} }
return fmt.Errorf("unknown User field %s", name) return fmt.Errorf("unknown User field %s", name)
} }
@@ -13627,8 +13716,8 @@ func (m *UserMutation) ResetField(name string) error {
// AddedEdges returns all edge names that were set/added in this mutation. // AddedEdges returns all edge names that were set/added in this mutation.
func (m *UserMutation) AddedEdges() []string { func (m *UserMutation) AddedEdges() []string {
edges := make([]string, 0, 3) edges := make([]string, 0, 3)
if m.group != nil { if m.groups != nil {
edges = append(edges, user.EdgeGroup) edges = append(edges, user.EdgeGroups)
} }
if m.auth_tokens != nil { if m.auth_tokens != nil {
edges = append(edges, user.EdgeAuthTokens) edges = append(edges, user.EdgeAuthTokens)
@@ -13643,10 +13732,12 @@ func (m *UserMutation) AddedEdges() []string {
// name in this mutation. // name in this mutation.
func (m *UserMutation) AddedIDs(name string) []ent.Value { func (m *UserMutation) AddedIDs(name string) []ent.Value {
switch name { switch name {
case user.EdgeGroup: case user.EdgeGroups:
if id := m.group; id != nil { ids := make([]ent.Value, 0, len(m.groups))
return []ent.Value{*id} for id := range m.groups {
ids = append(ids, id)
} }
return ids
case user.EdgeAuthTokens: case user.EdgeAuthTokens:
ids := make([]ent.Value, 0, len(m.auth_tokens)) ids := make([]ent.Value, 0, len(m.auth_tokens))
for id := range m.auth_tokens { for id := range m.auth_tokens {
@@ -13666,6 +13757,9 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value {
// RemovedEdges returns all edge names that were removed in this mutation. // RemovedEdges returns all edge names that were removed in this mutation.
func (m *UserMutation) RemovedEdges() []string { func (m *UserMutation) RemovedEdges() []string {
edges := make([]string, 0, 3) edges := make([]string, 0, 3)
if m.removedgroups != nil {
edges = append(edges, user.EdgeGroups)
}
if m.removedauth_tokens != nil { if m.removedauth_tokens != nil {
edges = append(edges, user.EdgeAuthTokens) edges = append(edges, user.EdgeAuthTokens)
} }
@@ -13679,6 +13773,12 @@ func (m *UserMutation) RemovedEdges() []string {
// the given name in this mutation. // the given name in this mutation.
func (m *UserMutation) RemovedIDs(name string) []ent.Value { func (m *UserMutation) RemovedIDs(name string) []ent.Value {
switch name { switch name {
case user.EdgeGroups:
ids := make([]ent.Value, 0, len(m.removedgroups))
for id := range m.removedgroups {
ids = append(ids, id)
}
return ids
case user.EdgeAuthTokens: case user.EdgeAuthTokens:
ids := make([]ent.Value, 0, len(m.removedauth_tokens)) ids := make([]ent.Value, 0, len(m.removedauth_tokens))
for id := range m.removedauth_tokens { for id := range m.removedauth_tokens {
@@ -13698,8 +13798,8 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value {
// ClearedEdges returns all edge names that were cleared in this mutation. // ClearedEdges returns all edge names that were cleared in this mutation.
func (m *UserMutation) ClearedEdges() []string { func (m *UserMutation) ClearedEdges() []string {
edges := make([]string, 0, 3) edges := make([]string, 0, 3)
if m.clearedgroup { if m.clearedgroups {
edges = append(edges, user.EdgeGroup) edges = append(edges, user.EdgeGroups)
} }
if m.clearedauth_tokens { if m.clearedauth_tokens {
edges = append(edges, user.EdgeAuthTokens) edges = append(edges, user.EdgeAuthTokens)
@@ -13714,8 +13814,8 @@ func (m *UserMutation) ClearedEdges() []string {
// was cleared in this mutation. // was cleared in this mutation.
func (m *UserMutation) EdgeCleared(name string) bool { func (m *UserMutation) EdgeCleared(name string) bool {
switch name { switch name {
case user.EdgeGroup: case user.EdgeGroups:
return m.clearedgroup return m.clearedgroups
case user.EdgeAuthTokens: case user.EdgeAuthTokens:
return m.clearedauth_tokens return m.clearedauth_tokens
case user.EdgeNotifiers: case user.EdgeNotifiers:
@@ -13728,9 +13828,6 @@ func (m *UserMutation) EdgeCleared(name string) bool {
// if that edge is not defined in the schema. // if that edge is not defined in the schema.
func (m *UserMutation) ClearEdge(name string) error { func (m *UserMutation) ClearEdge(name string) error {
switch name { switch name {
case user.EdgeGroup:
m.ClearGroup()
return nil
} }
return fmt.Errorf("unknown User unique edge %s", name) return fmt.Errorf("unknown User unique edge %s", name)
} }
@@ -13739,8 +13836,8 @@ func (m *UserMutation) ClearEdge(name string) error {
// It returns an error if the edge is not defined in the schema. // It returns an error if the edge is not defined in the schema.
func (m *UserMutation) ResetEdge(name string) error { func (m *UserMutation) ResetEdge(name string) error {
switch name { switch name {
case user.EdgeGroup: case user.EdgeGroups:
m.ResetGroup() m.ResetGroups()
return nil return nil
case user.EdgeAuthTokens: case user.EdgeAuthTokens:
m.ResetAuthTokens() m.ResetAuthTokens()

View File

@@ -42,7 +42,7 @@ func (Group) Edges() []ent.Edge {
} }
return []ent.Edge{ return []ent.Edge{
owned("users", User.Type), edge.To("users", User.Type),
owned("locations", Location.Type), owned("locations", Location.Type),
owned("items", Item.Type), owned("items", Item.Type),
owned("labels", Label.Type), owned("labels", Label.Type),
@@ -72,14 +72,14 @@ func (g GroupMixin) Fields() []ent.Field {
} }
func (g GroupMixin) Edges() []ent.Edge { func (g GroupMixin) Edges() []ent.Edge {
edge := edge.From("group", Group.Type). e := edge.From("group", Group.Type).
Ref(g.ref). Ref(g.ref).
Unique(). Unique().
Required() Required()
if g.field != "" { if g.field != "" {
edge = edge.Field(g.field) e = e.Field(g.field)
} }
return []ent.Edge{edge} return []ent.Edge{e}
} }

View File

@@ -19,7 +19,6 @@ type User struct {
func (User) Mixin() []ent.Mixin { func (User) Mixin() []ent.Mixin {
return []ent.Mixin{ return []ent.Mixin{
mixins.BaseMixin{}, mixins.BaseMixin{},
GroupMixin{ref: "users"},
} }
} }
@@ -54,6 +53,10 @@ func (User) Fields() []ent.Field {
field.String("oidc_subject"). field.String("oidc_subject").
Optional(). Optional().
Nillable(), Nillable(),
// default_group_id is the user's primary tenant/group
field.UUID("default_group_id", uuid.UUID{}).
Optional().
Nillable(),
} }
} }
@@ -66,6 +69,7 @@ func (User) Indexes() []ent.Index {
// Edges of the User. // Edges of the User.
func (User) Edges() []ent.Edge { func (User) Edges() []ent.Edge {
return []ent.Edge{ return []ent.Edge{
edge.To("groups", Group.Type),
edge.To("auth_tokens", AuthTokens.Type). edge.To("auth_tokens", AuthTokens.Type).
Annotations(entsql.Annotation{ Annotations(entsql.Annotation{
OnDelete: entsql.Cascade, OnDelete: entsql.Cascade,

View File

@@ -10,7 +10,6 @@ import (
"entgo.io/ent" "entgo.io/ent"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/sysadminsmedia/homebox/backend/internal/data/ent/group"
"github.com/sysadminsmedia/homebox/backend/internal/data/ent/user" "github.com/sysadminsmedia/homebox/backend/internal/data/ent/user"
) )
@@ -41,6 +40,8 @@ type User struct {
OidcIssuer *string `json:"oidc_issuer,omitempty"` OidcIssuer *string `json:"oidc_issuer,omitempty"`
// OidcSubject holds the value of the "oidc_subject" field. // OidcSubject holds the value of the "oidc_subject" field.
OidcSubject *string `json:"oidc_subject,omitempty"` OidcSubject *string `json:"oidc_subject,omitempty"`
// DefaultGroupID holds the value of the "default_group_id" field.
DefaultGroupID *uuid.UUID `json:"default_group_id,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set. // The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"` Edges UserEdges `json:"edges"`
@@ -50,8 +51,8 @@ type User struct {
// UserEdges holds the relations/edges for other nodes in the graph. // UserEdges holds the relations/edges for other nodes in the graph.
type UserEdges struct { type UserEdges struct {
// Group holds the value of the group edge. // Groups holds the value of the groups edge.
Group *Group `json:"group,omitempty"` Groups []*Group `json:"groups,omitempty"`
// AuthTokens holds the value of the auth_tokens edge. // AuthTokens holds the value of the auth_tokens edge.
AuthTokens []*AuthTokens `json:"auth_tokens,omitempty"` AuthTokens []*AuthTokens `json:"auth_tokens,omitempty"`
// Notifiers holds the value of the notifiers edge. // Notifiers holds the value of the notifiers edge.
@@ -61,15 +62,13 @@ type UserEdges struct {
loadedTypes [3]bool loadedTypes [3]bool
} }
// GroupOrErr returns the Group value or an error if the edge // GroupsOrErr returns the Groups value or an error if the edge
// was not loaded in eager-loading, or loaded but was not found. // was not loaded in eager-loading.
func (e UserEdges) GroupOrErr() (*Group, error) { func (e UserEdges) GroupsOrErr() ([]*Group, error) {
if e.Group != nil { if e.loadedTypes[0] {
return e.Group, nil return e.Groups, nil
} else if e.loadedTypes[0] {
return nil, &NotFoundError{label: group.Label}
} }
return nil, &NotLoadedError{edge: "group"} return nil, &NotLoadedError{edge: "groups"}
} }
// AuthTokensOrErr returns the AuthTokens value or an error if the edge // AuthTokensOrErr returns the AuthTokens value or an error if the edge
@@ -95,6 +94,8 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case user.FieldDefaultGroupID:
values[i] = &sql.NullScanner{S: new(uuid.UUID)}
case user.FieldIsSuperuser, user.FieldSuperuser: case user.FieldIsSuperuser, user.FieldSuperuser:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case user.FieldName, user.FieldEmail, user.FieldPassword, user.FieldRole, user.FieldOidcIssuer, user.FieldOidcSubject: case user.FieldName, user.FieldEmail, user.FieldPassword, user.FieldRole, user.FieldOidcIssuer, user.FieldOidcSubject:
@@ -195,6 +196,13 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.OidcSubject = new(string) _m.OidcSubject = new(string)
*_m.OidcSubject = value.String *_m.OidcSubject = value.String
} }
case user.FieldDefaultGroupID:
if value, ok := values[i].(*sql.NullScanner); !ok {
return fmt.Errorf("unexpected type %T for field default_group_id", values[i])
} else if value.Valid {
_m.DefaultGroupID = new(uuid.UUID)
*_m.DefaultGroupID = *value.S.(*uuid.UUID)
}
case user.ForeignKeys[0]: case user.ForeignKeys[0]:
if value, ok := values[i].(*sql.NullScanner); !ok { if value, ok := values[i].(*sql.NullScanner); !ok {
return fmt.Errorf("unexpected type %T for field group_users", values[i]) return fmt.Errorf("unexpected type %T for field group_users", values[i])
@@ -215,9 +223,9 @@ func (_m *User) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name) return _m.selectValues.Get(name)
} }
// QueryGroup queries the "group" edge of the User entity. // QueryGroups queries the "groups" edge of the User entity.
func (_m *User) QueryGroup() *GroupQuery { func (_m *User) QueryGroups() *GroupQuery {
return NewUserClient(_m.config).QueryGroup(_m) return NewUserClient(_m.config).QueryGroups(_m)
} }
// QueryAuthTokens queries the "auth_tokens" edge of the User entity. // QueryAuthTokens queries the "auth_tokens" edge of the User entity.
@@ -288,6 +296,11 @@ func (_m *User) String() string {
builder.WriteString("oidc_subject=") builder.WriteString("oidc_subject=")
builder.WriteString(*v) builder.WriteString(*v)
} }
builder.WriteString(", ")
if v := _m.DefaultGroupID; v != nil {
builder.WriteString("default_group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }

View File

@@ -38,21 +38,23 @@ const (
FieldOidcIssuer = "oidc_issuer" FieldOidcIssuer = "oidc_issuer"
// FieldOidcSubject holds the string denoting the oidc_subject field in the database. // FieldOidcSubject holds the string denoting the oidc_subject field in the database.
FieldOidcSubject = "oidc_subject" FieldOidcSubject = "oidc_subject"
// EdgeGroup holds the string denoting the group edge name in mutations. // FieldDefaultGroupID holds the string denoting the default_group_id field in the database.
EdgeGroup = "group" FieldDefaultGroupID = "default_group_id"
// EdgeGroups holds the string denoting the groups edge name in mutations.
EdgeGroups = "groups"
// EdgeAuthTokens holds the string denoting the auth_tokens edge name in mutations. // EdgeAuthTokens holds the string denoting the auth_tokens edge name in mutations.
EdgeAuthTokens = "auth_tokens" EdgeAuthTokens = "auth_tokens"
// EdgeNotifiers holds the string denoting the notifiers edge name in mutations. // EdgeNotifiers holds the string denoting the notifiers edge name in mutations.
EdgeNotifiers = "notifiers" EdgeNotifiers = "notifiers"
// Table holds the table name of the user in the database. // Table holds the table name of the user in the database.
Table = "users" Table = "users"
// GroupTable is the table that holds the group relation/edge. // GroupsTable is the table that holds the groups relation/edge.
GroupTable = "users" GroupsTable = "groups"
// GroupInverseTable is the table name for the Group entity. // GroupsInverseTable is the table name for the Group entity.
// It exists in this package in order to avoid circular dependency with the "group" package. // It exists in this package in order to avoid circular dependency with the "group" package.
GroupInverseTable = "groups" GroupsInverseTable = "groups"
// GroupColumn is the table column denoting the group relation/edge. // GroupsColumn is the table column denoting the groups relation/edge.
GroupColumn = "group_users" GroupsColumn = "user_groups"
// AuthTokensTable is the table that holds the auth_tokens relation/edge. // AuthTokensTable is the table that holds the auth_tokens relation/edge.
AuthTokensTable = "auth_tokens" AuthTokensTable = "auth_tokens"
// AuthTokensInverseTable is the table name for the AuthTokens entity. // AuthTokensInverseTable is the table name for the AuthTokens entity.
@@ -83,6 +85,7 @@ var Columns = []string{
FieldActivatedOn, FieldActivatedOn,
FieldOidcIssuer, FieldOidcIssuer,
FieldOidcSubject, FieldOidcSubject,
FieldDefaultGroupID,
} }
// ForeignKeys holds the SQL foreign-keys that are owned by the "users" // ForeignKeys holds the SQL foreign-keys that are owned by the "users"
@@ -216,10 +219,22 @@ func ByOidcSubject(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldOidcSubject, opts...).ToFunc() return sql.OrderByField(FieldOidcSubject, opts...).ToFunc()
} }
// ByGroupField orders the results by group field. // ByDefaultGroupID orders the results by the default_group_id field.
func ByGroupField(field string, opts ...sql.OrderTermOption) OrderOption { func ByDefaultGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultGroupID, opts...).ToFunc()
}
// ByGroupsCount orders the results by groups count.
func ByGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) { return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newGroupStep(), sql.OrderByField(field, opts...)) sqlgraph.OrderByNeighborsCount(s, newGroupsStep(), opts...)
}
}
// ByGroups orders the results by groups terms.
func ByGroups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
return func(s *sql.Selector) {
sqlgraph.OrderByNeighborTerms(s, newGroupsStep(), append([]sql.OrderTerm{term}, terms...)...)
} }
} }
@@ -250,11 +265,11 @@ func ByNotifiers(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
sqlgraph.OrderByNeighborTerms(s, newNotifiersStep(), append([]sql.OrderTerm{term}, terms...)...) sqlgraph.OrderByNeighborTerms(s, newNotifiersStep(), append([]sql.OrderTerm{term}, terms...)...)
} }
} }
func newGroupStep() *sqlgraph.Step { func newGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep( return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID), sqlgraph.From(Table, FieldID),
sqlgraph.To(GroupInverseTable, FieldID), sqlgraph.To(GroupsInverseTable, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), sqlgraph.Edge(sqlgraph.O2M, false, GroupsTable, GroupsColumn),
) )
} }
func newAuthTokensStep() *sqlgraph.Step { func newAuthTokensStep() *sqlgraph.Step {

View File

@@ -106,6 +106,11 @@ func OidcSubject(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldOidcSubject, v)) return predicate.User(sql.FieldEQ(FieldOidcSubject, v))
} }
// DefaultGroupID applies equality check predicate on the "default_group_id" field. It's identical to DefaultGroupIDEQ.
func DefaultGroupID(v uuid.UUID) predicate.User {
return predicate.User(sql.FieldEQ(FieldDefaultGroupID, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User { func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -631,21 +636,71 @@ func OidcSubjectContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldOidcSubject, v)) return predicate.User(sql.FieldContainsFold(FieldOidcSubject, v))
} }
// HasGroup applies the HasEdge predicate on the "group" edge. // DefaultGroupIDEQ applies the EQ predicate on the "default_group_id" field.
func HasGroup() predicate.User { func DefaultGroupIDEQ(v uuid.UUID) predicate.User {
return predicate.User(sql.FieldEQ(FieldDefaultGroupID, v))
}
// DefaultGroupIDNEQ applies the NEQ predicate on the "default_group_id" field.
func DefaultGroupIDNEQ(v uuid.UUID) predicate.User {
return predicate.User(sql.FieldNEQ(FieldDefaultGroupID, v))
}
// DefaultGroupIDIn applies the In predicate on the "default_group_id" field.
func DefaultGroupIDIn(vs ...uuid.UUID) predicate.User {
return predicate.User(sql.FieldIn(FieldDefaultGroupID, vs...))
}
// DefaultGroupIDNotIn applies the NotIn predicate on the "default_group_id" field.
func DefaultGroupIDNotIn(vs ...uuid.UUID) predicate.User {
return predicate.User(sql.FieldNotIn(FieldDefaultGroupID, vs...))
}
// DefaultGroupIDGT applies the GT predicate on the "default_group_id" field.
func DefaultGroupIDGT(v uuid.UUID) predicate.User {
return predicate.User(sql.FieldGT(FieldDefaultGroupID, v))
}
// DefaultGroupIDGTE applies the GTE predicate on the "default_group_id" field.
func DefaultGroupIDGTE(v uuid.UUID) predicate.User {
return predicate.User(sql.FieldGTE(FieldDefaultGroupID, v))
}
// DefaultGroupIDLT applies the LT predicate on the "default_group_id" field.
func DefaultGroupIDLT(v uuid.UUID) predicate.User {
return predicate.User(sql.FieldLT(FieldDefaultGroupID, v))
}
// DefaultGroupIDLTE applies the LTE predicate on the "default_group_id" field.
func DefaultGroupIDLTE(v uuid.UUID) predicate.User {
return predicate.User(sql.FieldLTE(FieldDefaultGroupID, v))
}
// DefaultGroupIDIsNil applies the IsNil predicate on the "default_group_id" field.
func DefaultGroupIDIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldDefaultGroupID))
}
// DefaultGroupIDNotNil applies the NotNil predicate on the "default_group_id" field.
func DefaultGroupIDNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldDefaultGroupID))
}
// HasGroups applies the HasEdge predicate on the "groups" edge.
func HasGroups() predicate.User {
return predicate.User(func(s *sql.Selector) { return predicate.User(func(s *sql.Selector) {
step := sqlgraph.NewStep( step := sqlgraph.NewStep(
sqlgraph.From(Table, FieldID), sqlgraph.From(Table, FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, GroupTable, GroupColumn), sqlgraph.Edge(sqlgraph.O2M, false, GroupsTable, GroupsColumn),
) )
sqlgraph.HasNeighbors(s, step) sqlgraph.HasNeighbors(s, step)
}) })
} }
// HasGroupWith applies the HasEdge predicate on the "group" edge with a given conditions (other predicates). // HasGroupsWith applies the HasEdge predicate on the "groups" edge with a given conditions (other predicates).
func HasGroupWith(preds ...predicate.Group) predicate.User { func HasGroupsWith(preds ...predicate.Group) predicate.User {
return predicate.User(func(s *sql.Selector) { return predicate.User(func(s *sql.Selector) {
step := newGroupStep() step := newGroupsStep()
sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
for _, p := range preds { for _, p := range preds {
p(s) p(s)

View File

@@ -162,6 +162,20 @@ func (_c *UserCreate) SetNillableOidcSubject(v *string) *UserCreate {
return _c return _c
} }
// SetDefaultGroupID sets the "default_group_id" field.
func (_c *UserCreate) SetDefaultGroupID(v uuid.UUID) *UserCreate {
_c.mutation.SetDefaultGroupID(v)
return _c
}
// SetNillableDefaultGroupID sets the "default_group_id" field if the given value is not nil.
func (_c *UserCreate) SetNillableDefaultGroupID(v *uuid.UUID) *UserCreate {
if v != nil {
_c.SetDefaultGroupID(*v)
}
return _c
}
// SetID sets the "id" field. // SetID sets the "id" field.
func (_c *UserCreate) SetID(v uuid.UUID) *UserCreate { func (_c *UserCreate) SetID(v uuid.UUID) *UserCreate {
_c.mutation.SetID(v) _c.mutation.SetID(v)
@@ -176,15 +190,19 @@ func (_c *UserCreate) SetNillableID(v *uuid.UUID) *UserCreate {
return _c return _c
} }
// SetGroupID sets the "group" edge to the Group entity by ID. // AddGroupIDs adds the "groups" edge to the Group entity by IDs.
func (_c *UserCreate) SetGroupID(id uuid.UUID) *UserCreate { func (_c *UserCreate) AddGroupIDs(ids ...uuid.UUID) *UserCreate {
_c.mutation.SetGroupID(id) _c.mutation.AddGroupIDs(ids...)
return _c return _c
} }
// SetGroup sets the "group" edge to the Group entity. // AddGroups adds the "groups" edges to the Group entity.
func (_c *UserCreate) SetGroup(v *Group) *UserCreate { func (_c *UserCreate) AddGroups(v ...*Group) *UserCreate {
return _c.SetGroupID(v.ID) ids := make([]uuid.UUID, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _c.AddGroupIDs(ids...)
} }
// AddAuthTokenIDs adds the "auth_tokens" edge to the AuthTokens entity by IDs. // AddAuthTokenIDs adds the "auth_tokens" edge to the AuthTokens entity by IDs.
@@ -321,9 +339,6 @@ func (_c *UserCreate) check() error {
return &ValidationError{Name: "role", err: fmt.Errorf(`ent: validator failed for field "User.role": %w`, err)} return &ValidationError{Name: "role", err: fmt.Errorf(`ent: validator failed for field "User.role": %w`, err)}
} }
} }
if len(_c.mutation.GroupIDs()) == 0 {
return &ValidationError{Name: "group", err: errors.New(`ent: missing required edge "User.group"`)}
}
return nil return nil
} }
@@ -403,12 +418,16 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldOidcSubject, field.TypeString, value) _spec.SetField(user.FieldOidcSubject, field.TypeString, value)
_node.OidcSubject = &value _node.OidcSubject = &value
} }
if nodes := _c.mutation.GroupIDs(); len(nodes) > 0 { if value, ok := _c.mutation.DefaultGroupID(); ok {
_spec.SetField(user.FieldDefaultGroupID, field.TypeUUID, value)
_node.DefaultGroupID = &value
}
if nodes := _c.mutation.GroupsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.O2M,
Inverse: true, Inverse: false,
Table: user.GroupTable, Table: user.GroupsTable,
Columns: []string{user.GroupColumn}, Columns: []string{user.GroupsColumn},
Bidi: false, Bidi: false,
Target: &sqlgraph.EdgeTarget{ Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID), IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID),
@@ -417,7 +436,6 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
for _, k := range nodes { for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k) edge.Target.Nodes = append(edge.Target.Nodes, k)
} }
_node.group_users = &nodes[0]
_spec.Edges = append(_spec.Edges, edge) _spec.Edges = append(_spec.Edges, edge)
} }
if nodes := _c.mutation.AuthTokensIDs(); len(nodes) > 0 { if nodes := _c.mutation.AuthTokensIDs(); len(nodes) > 0 {

View File

@@ -27,7 +27,7 @@ type UserQuery struct {
order []user.OrderOption order []user.OrderOption
inters []Interceptor inters []Interceptor
predicates []predicate.User predicates []predicate.User
withGroup *GroupQuery withGroups *GroupQuery
withAuthTokens *AuthTokensQuery withAuthTokens *AuthTokensQuery
withNotifiers *NotifierQuery withNotifiers *NotifierQuery
withFKs bool withFKs bool
@@ -67,8 +67,8 @@ func (_q *UserQuery) Order(o ...user.OrderOption) *UserQuery {
return _q return _q
} }
// QueryGroup chains the current query on the "group" edge. // QueryGroups chains the current query on the "groups" edge.
func (_q *UserQuery) QueryGroup() *GroupQuery { func (_q *UserQuery) QueryGroups() *GroupQuery {
query := (&GroupClient{config: _q.config}).Query() query := (&GroupClient{config: _q.config}).Query()
query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
if err := _q.prepareQuery(ctx); err != nil { if err := _q.prepareQuery(ctx); err != nil {
@@ -81,7 +81,7 @@ func (_q *UserQuery) QueryGroup() *GroupQuery {
step := sqlgraph.NewStep( step := sqlgraph.NewStep(
sqlgraph.From(user.Table, user.FieldID, selector), sqlgraph.From(user.Table, user.FieldID, selector),
sqlgraph.To(group.Table, group.FieldID), sqlgraph.To(group.Table, group.FieldID),
sqlgraph.Edge(sqlgraph.M2O, true, user.GroupTable, user.GroupColumn), sqlgraph.Edge(sqlgraph.O2M, false, user.GroupsTable, user.GroupsColumn),
) )
fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
return fromU, nil return fromU, nil
@@ -325,7 +325,7 @@ func (_q *UserQuery) Clone() *UserQuery {
order: append([]user.OrderOption{}, _q.order...), order: append([]user.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...), inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.User{}, _q.predicates...), predicates: append([]predicate.User{}, _q.predicates...),
withGroup: _q.withGroup.Clone(), withGroups: _q.withGroups.Clone(),
withAuthTokens: _q.withAuthTokens.Clone(), withAuthTokens: _q.withAuthTokens.Clone(),
withNotifiers: _q.withNotifiers.Clone(), withNotifiers: _q.withNotifiers.Clone(),
// clone intermediate query. // clone intermediate query.
@@ -334,14 +334,14 @@ func (_q *UserQuery) Clone() *UserQuery {
} }
} }
// WithGroup tells the query-builder to eager-load the nodes that are connected to // WithGroups tells the query-builder to eager-load the nodes that are connected to
// the "group" edge. The optional arguments are used to configure the query builder of the edge. // the "groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithGroup(opts ...func(*GroupQuery)) *UserQuery { func (_q *UserQuery) WithGroups(opts ...func(*GroupQuery)) *UserQuery {
query := (&GroupClient{config: _q.config}).Query() query := (&GroupClient{config: _q.config}).Query()
for _, opt := range opts { for _, opt := range opts {
opt(query) opt(query)
} }
_q.withGroup = query _q.withGroups = query
return _q return _q
} }
@@ -447,14 +447,11 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
withFKs = _q.withFKs withFKs = _q.withFKs
_spec = _q.querySpec() _spec = _q.querySpec()
loadedTypes = [3]bool{ loadedTypes = [3]bool{
_q.withGroup != nil, _q.withGroups != nil,
_q.withAuthTokens != nil, _q.withAuthTokens != nil,
_q.withNotifiers != nil, _q.withNotifiers != nil,
} }
) )
if _q.withGroup != nil {
withFKs = true
}
if withFKs { if withFKs {
_spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...) _spec.Node.Columns = append(_spec.Node.Columns, user.ForeignKeys...)
} }
@@ -476,9 +473,10 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
if len(nodes) == 0 { if len(nodes) == 0 {
return nodes, nil return nodes, nil
} }
if query := _q.withGroup; query != nil { if query := _q.withGroups; query != nil {
if err := _q.loadGroup(ctx, query, nodes, nil, if err := _q.loadGroups(ctx, query, nodes,
func(n *User, e *Group) { n.Edges.Group = e }); err != nil { func(n *User) { n.Edges.Groups = []*Group{} },
func(n *User, e *Group) { n.Edges.Groups = append(n.Edges.Groups, e) }); err != nil {
return nil, err return nil, err
} }
} }
@@ -499,35 +497,34 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nodes, nil return nodes, nil
} }
func (_q *UserQuery) loadGroup(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error { func (_q *UserQuery) loadGroups(ctx context.Context, query *GroupQuery, nodes []*User, init func(*User), assign func(*User, *Group)) error {
ids := make([]uuid.UUID, 0, len(nodes)) fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[uuid.UUID][]*User) nodeids := make(map[uuid.UUID]*User)
for i := range nodes { for i := range nodes {
if nodes[i].group_users == nil { fks = append(fks, nodes[i].ID)
continue nodeids[nodes[i].ID] = nodes[i]
if init != nil {
init(nodes[i])
} }
fk := *nodes[i].group_users
if _, ok := nodeids[fk]; !ok {
ids = append(ids, fk)
}
nodeids[fk] = append(nodeids[fk], nodes[i])
} }
if len(ids) == 0 { query.withFKs = true
return nil query.Where(predicate.Group(func(s *sql.Selector) {
} s.Where(sql.InValues(s.C(user.GroupsColumn), fks...))
query.Where(group.IDIn(ids...)) }))
neighbors, err := query.All(ctx) neighbors, err := query.All(ctx)
if err != nil { if err != nil {
return err return err
} }
for _, n := range neighbors { for _, n := range neighbors {
nodes, ok := nodeids[n.ID] fk := n.user_groups
if fk == nil {
return fmt.Errorf(`foreign-key "user_groups" is nil for node %v`, n.ID)
}
node, ok := nodeids[*fk]
if !ok { if !ok {
return fmt.Errorf(`unexpected foreign-key "group_users" returned %v`, n.ID) return fmt.Errorf(`unexpected referenced foreign-key "user_groups" returned %v for node %v`, *fk, n.ID)
}
for i := range nodes {
assign(nodes[i], n)
} }
assign(node, n)
} }
return nil return nil
} }

View File

@@ -188,15 +188,39 @@ func (_u *UserUpdate) ClearOidcSubject() *UserUpdate {
return _u return _u
} }
// SetGroupID sets the "group" edge to the Group entity by ID. // SetDefaultGroupID sets the "default_group_id" field.
func (_u *UserUpdate) SetGroupID(id uuid.UUID) *UserUpdate { func (_u *UserUpdate) SetDefaultGroupID(v uuid.UUID) *UserUpdate {
_u.mutation.SetGroupID(id) _u.mutation.SetDefaultGroupID(v)
return _u return _u
} }
// SetGroup sets the "group" edge to the Group entity. // SetNillableDefaultGroupID sets the "default_group_id" field if the given value is not nil.
func (_u *UserUpdate) SetGroup(v *Group) *UserUpdate { func (_u *UserUpdate) SetNillableDefaultGroupID(v *uuid.UUID) *UserUpdate {
return _u.SetGroupID(v.ID) if v != nil {
_u.SetDefaultGroupID(*v)
}
return _u
}
// ClearDefaultGroupID clears the value of the "default_group_id" field.
func (_u *UserUpdate) ClearDefaultGroupID() *UserUpdate {
_u.mutation.ClearDefaultGroupID()
return _u
}
// AddGroupIDs adds the "groups" edge to the Group entity by IDs.
func (_u *UserUpdate) AddGroupIDs(ids ...uuid.UUID) *UserUpdate {
_u.mutation.AddGroupIDs(ids...)
return _u
}
// AddGroups adds the "groups" edges to the Group entity.
func (_u *UserUpdate) AddGroups(v ...*Group) *UserUpdate {
ids := make([]uuid.UUID, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddGroupIDs(ids...)
} }
// AddAuthTokenIDs adds the "auth_tokens" edge to the AuthTokens entity by IDs. // AddAuthTokenIDs adds the "auth_tokens" edge to the AuthTokens entity by IDs.
@@ -234,12 +258,27 @@ func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation return _u.mutation
} }
// ClearGroup clears the "group" edge to the Group entity. // ClearGroups clears all "groups" edges to the Group entity.
func (_u *UserUpdate) ClearGroup() *UserUpdate { func (_u *UserUpdate) ClearGroups() *UserUpdate {
_u.mutation.ClearGroup() _u.mutation.ClearGroups()
return _u return _u
} }
// RemoveGroupIDs removes the "groups" edge to Group entities by IDs.
func (_u *UserUpdate) RemoveGroupIDs(ids ...uuid.UUID) *UserUpdate {
_u.mutation.RemoveGroupIDs(ids...)
return _u
}
// RemoveGroups removes "groups" edges to Group entities.
func (_u *UserUpdate) RemoveGroups(v ...*Group) *UserUpdate {
ids := make([]uuid.UUID, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveGroupIDs(ids...)
}
// ClearAuthTokens clears all "auth_tokens" edges to the AuthTokens entity. // ClearAuthTokens clears all "auth_tokens" edges to the AuthTokens entity.
func (_u *UserUpdate) ClearAuthTokens() *UserUpdate { func (_u *UserUpdate) ClearAuthTokens() *UserUpdate {
_u.mutation.ClearAuthTokens() _u.mutation.ClearAuthTokens()
@@ -340,9 +379,6 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "role", err: fmt.Errorf(`ent: validator failed for field "User.role": %w`, err)} return &ValidationError{Name: "role", err: fmt.Errorf(`ent: validator failed for field "User.role": %w`, err)}
} }
} }
if _u.mutation.GroupCleared() && len(_u.mutation.GroupIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "User.group"`)
}
return nil return nil
} }
@@ -400,12 +436,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.OidcSubjectCleared() { if _u.mutation.OidcSubjectCleared() {
_spec.ClearField(user.FieldOidcSubject, field.TypeString) _spec.ClearField(user.FieldOidcSubject, field.TypeString)
} }
if _u.mutation.GroupCleared() { if value, ok := _u.mutation.DefaultGroupID(); ok {
_spec.SetField(user.FieldDefaultGroupID, field.TypeUUID, value)
}
if _u.mutation.DefaultGroupIDCleared() {
_spec.ClearField(user.FieldDefaultGroupID, field.TypeUUID)
}
if _u.mutation.GroupsCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.O2M,
Inverse: true, Inverse: false,
Table: user.GroupTable, Table: user.GroupsTable,
Columns: []string{user.GroupColumn}, Columns: []string{user.GroupsColumn},
Bidi: false, Bidi: false,
Target: &sqlgraph.EdgeTarget{ Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID), IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID),
@@ -413,12 +455,28 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
} }
_spec.Edges.Clear = append(_spec.Edges.Clear, edge) _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
} }
if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 { if nodes := _u.mutation.RemovedGroupsIDs(); len(nodes) > 0 && !_u.mutation.GroupsCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.O2M,
Inverse: true, Inverse: false,
Table: user.GroupTable, Table: user.GroupsTable,
Columns: []string{user.GroupColumn}, Columns: []string{user.GroupsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.GroupsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.GroupsTable,
Columns: []string{user.GroupsColumn},
Bidi: false, Bidi: false,
Target: &sqlgraph.EdgeTarget{ Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID), IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID),
@@ -695,15 +753,39 @@ func (_u *UserUpdateOne) ClearOidcSubject() *UserUpdateOne {
return _u return _u
} }
// SetGroupID sets the "group" edge to the Group entity by ID. // SetDefaultGroupID sets the "default_group_id" field.
func (_u *UserUpdateOne) SetGroupID(id uuid.UUID) *UserUpdateOne { func (_u *UserUpdateOne) SetDefaultGroupID(v uuid.UUID) *UserUpdateOne {
_u.mutation.SetGroupID(id) _u.mutation.SetDefaultGroupID(v)
return _u return _u
} }
// SetGroup sets the "group" edge to the Group entity. // SetNillableDefaultGroupID sets the "default_group_id" field if the given value is not nil.
func (_u *UserUpdateOne) SetGroup(v *Group) *UserUpdateOne { func (_u *UserUpdateOne) SetNillableDefaultGroupID(v *uuid.UUID) *UserUpdateOne {
return _u.SetGroupID(v.ID) if v != nil {
_u.SetDefaultGroupID(*v)
}
return _u
}
// ClearDefaultGroupID clears the value of the "default_group_id" field.
func (_u *UserUpdateOne) ClearDefaultGroupID() *UserUpdateOne {
_u.mutation.ClearDefaultGroupID()
return _u
}
// AddGroupIDs adds the "groups" edge to the Group entity by IDs.
func (_u *UserUpdateOne) AddGroupIDs(ids ...uuid.UUID) *UserUpdateOne {
_u.mutation.AddGroupIDs(ids...)
return _u
}
// AddGroups adds the "groups" edges to the Group entity.
func (_u *UserUpdateOne) AddGroups(v ...*Group) *UserUpdateOne {
ids := make([]uuid.UUID, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.AddGroupIDs(ids...)
} }
// AddAuthTokenIDs adds the "auth_tokens" edge to the AuthTokens entity by IDs. // AddAuthTokenIDs adds the "auth_tokens" edge to the AuthTokens entity by IDs.
@@ -741,12 +823,27 @@ func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation return _u.mutation
} }
// ClearGroup clears the "group" edge to the Group entity. // ClearGroups clears all "groups" edges to the Group entity.
func (_u *UserUpdateOne) ClearGroup() *UserUpdateOne { func (_u *UserUpdateOne) ClearGroups() *UserUpdateOne {
_u.mutation.ClearGroup() _u.mutation.ClearGroups()
return _u return _u
} }
// RemoveGroupIDs removes the "groups" edge to Group entities by IDs.
func (_u *UserUpdateOne) RemoveGroupIDs(ids ...uuid.UUID) *UserUpdateOne {
_u.mutation.RemoveGroupIDs(ids...)
return _u
}
// RemoveGroups removes "groups" edges to Group entities.
func (_u *UserUpdateOne) RemoveGroups(v ...*Group) *UserUpdateOne {
ids := make([]uuid.UUID, len(v))
for i := range v {
ids[i] = v[i].ID
}
return _u.RemoveGroupIDs(ids...)
}
// ClearAuthTokens clears all "auth_tokens" edges to the AuthTokens entity. // ClearAuthTokens clears all "auth_tokens" edges to the AuthTokens entity.
func (_u *UserUpdateOne) ClearAuthTokens() *UserUpdateOne { func (_u *UserUpdateOne) ClearAuthTokens() *UserUpdateOne {
_u.mutation.ClearAuthTokens() _u.mutation.ClearAuthTokens()
@@ -860,9 +957,6 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "role", err: fmt.Errorf(`ent: validator failed for field "User.role": %w`, err)} return &ValidationError{Name: "role", err: fmt.Errorf(`ent: validator failed for field "User.role": %w`, err)}
} }
} }
if _u.mutation.GroupCleared() && len(_u.mutation.GroupIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "User.group"`)
}
return nil return nil
} }
@@ -937,12 +1031,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.OidcSubjectCleared() { if _u.mutation.OidcSubjectCleared() {
_spec.ClearField(user.FieldOidcSubject, field.TypeString) _spec.ClearField(user.FieldOidcSubject, field.TypeString)
} }
if _u.mutation.GroupCleared() { if value, ok := _u.mutation.DefaultGroupID(); ok {
_spec.SetField(user.FieldDefaultGroupID, field.TypeUUID, value)
}
if _u.mutation.DefaultGroupIDCleared() {
_spec.ClearField(user.FieldDefaultGroupID, field.TypeUUID)
}
if _u.mutation.GroupsCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.O2M,
Inverse: true, Inverse: false,
Table: user.GroupTable, Table: user.GroupsTable,
Columns: []string{user.GroupColumn}, Columns: []string{user.GroupsColumn},
Bidi: false, Bidi: false,
Target: &sqlgraph.EdgeTarget{ Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID), IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID),
@@ -950,12 +1050,28 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
} }
_spec.Edges.Clear = append(_spec.Edges.Clear, edge) _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
} }
if nodes := _u.mutation.GroupIDs(); len(nodes) > 0 { if nodes := _u.mutation.RemovedGroupsIDs(); len(nodes) > 0 && !_u.mutation.GroupsCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O, Rel: sqlgraph.O2M,
Inverse: true, Inverse: false,
Table: user.GroupTable, Table: user.GroupsTable,
Columns: []string{user.GroupColumn}, Columns: []string{user.GroupsColumn},
Bidi: false,
Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID),
},
}
for _, k := range nodes {
edge.Target.Nodes = append(edge.Target.Nodes, k)
}
_spec.Edges.Clear = append(_spec.Edges.Clear, edge)
}
if nodes := _u.mutation.GroupsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
Inverse: false,
Table: user.GroupsTable,
Columns: []string{user.GroupsColumn},
Bidi: false, Bidi: false,
Target: &sqlgraph.EdgeTarget{ Target: &sqlgraph.EdgeTarget{
IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID), IDSpec: sqlgraph.NewFieldSpec(group.FieldID, field.TypeUUID),

View File

@@ -0,0 +1,47 @@
-- +goose Up
-- Create user_groups junction table for M:M relationship
CREATE TABLE IF NOT EXISTS "user_groups" (
"user_id" uuid NOT NULL,
"group_id" uuid NOT NULL,
PRIMARY KEY ("user_id", "group_id"),
CONSTRAINT "user_groups_user_id" FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE CASCADE,
CONSTRAINT "user_groups_group_id" FOREIGN KEY ("group_id") REFERENCES "groups" ("id") ON UPDATE NO ACTION ON DELETE CASCADE
);
-- Migrate existing user->group relationships to the junction table
INSERT INTO "user_groups" ("user_id", "group_id")
SELECT "id", "group_users" FROM "users" WHERE "group_users" IS NOT NULL;
-- Add default_group_id column to users table
ALTER TABLE "users" ADD COLUMN "default_group_id" uuid;
-- Set default_group_id to the user's current group
UPDATE "users" SET "default_group_id" = "group_users" WHERE "group_users" IS NOT NULL;
-- Drop the old group_users foreign key constraint and column
ALTER TABLE "users" DROP CONSTRAINT "users_groups_users";
ALTER TABLE "users" DROP COLUMN "group_users";
-- Add foreign key constraint for default_group_id
ALTER TABLE "users" ADD CONSTRAINT "users_groups_users_default" FOREIGN KEY ("default_group_id") REFERENCES "groups" ("id") ON UPDATE NO ACTION ON DELETE SET NULL;
-- +goose Down
-- Recreate group_users column with foreign key
ALTER TABLE "users" ADD COLUMN "group_users" uuid;
-- Restore the group_users values from user_groups (using the default_group_id or first entry)
UPDATE "users"
SET "group_users" = COALESCE("default_group_id", (
SELECT "group_id" FROM "user_groups" WHERE "user_id" = "users"."id" LIMIT 1
));
-- Drop the default_group_id foreign key and column
ALTER TABLE "users" DROP CONSTRAINT "users_groups_users_default";
ALTER TABLE "users" DROP COLUMN "default_group_id";
-- Add back the original foreign key constraint
ALTER TABLE "users" ADD CONSTRAINT "users_groups_users" FOREIGN KEY ("group_users") REFERENCES "groups" ("id") ON UPDATE NO ACTION ON DELETE CASCADE;
-- Drop the junction table
DROP TABLE IF EXISTS "user_groups";

View File

@@ -0,0 +1,105 @@
-- +goose Up
-- Create user_groups junction table for M:M relationship
CREATE TABLE IF NOT EXISTS user_groups (
user_id UUID NOT NULL,
group_id UUID NOT NULL,
PRIMARY KEY (user_id, group_id),
CONSTRAINT user_groups_user_id FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
CONSTRAINT user_groups_group_id FOREIGN KEY (group_id) REFERENCES groups(id) ON DELETE CASCADE
);
-- Migrate existing user->group relationships to the junction table
INSERT INTO user_groups (user_id, group_id)
SELECT id, group_users FROM users WHERE group_users IS NOT NULL;
-- Add default_group_id column to users table
ALTER TABLE users ADD COLUMN default_group_id UUID;
-- Set default_group_id to the user's current group
UPDATE users SET default_group_id = group_users WHERE group_users IS NOT NULL;
-- Add foreign key constraint for default_group_id
CREATE TABLE users_new (
id UUID NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL,
name TEXT NOT NULL,
email TEXT NOT NULL UNIQUE,
password TEXT,
is_superuser BOOLEAN NOT NULL DEFAULT false,
superuser BOOLEAN NOT NULL DEFAULT false,
role TEXT NOT NULL DEFAULT 'user',
activated_on DATETIME,
oidc_issuer TEXT,
oidc_subject TEXT,
default_group_id UUID,
PRIMARY KEY (id),
CONSTRAINT users_groups_users_default FOREIGN KEY (default_group_id) REFERENCES groups(id) ON DELETE SET NULL,
UNIQUE (oidc_issuer, oidc_subject)
);
-- Copy data from old table to new table
INSERT INTO users_new (
id, created_at, updated_at, name, email, password, is_superuser, superuser, role,
activated_on, oidc_issuer, oidc_subject, default_group_id
)
SELECT
id, created_at, updated_at, name, email, password, is_superuser, superuser, role,
activated_on, oidc_issuer, oidc_subject, default_group_id
FROM users;
-- Drop old indexes
DROP INDEX IF EXISTS users_email_key;
DROP INDEX IF EXISTS users_oidc_issuer_subject_key;
-- Drop old table
DROP TABLE users;
-- Rename new table to users
ALTER TABLE users_new RENAME TO users;
-- Recreate indexes
CREATE UNIQUE INDEX IF NOT EXISTS users_email_key ON users(email);
CREATE UNIQUE INDEX IF NOT EXISTS users_oidc_issuer_subject_key ON users(oidc_issuer, oidc_subject);
-- +goose Down
-- Recreate the old schema
CREATE TABLE users_old (
id UUID NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL,
name TEXT NOT NULL,
email TEXT NOT NULL UNIQUE,
password TEXT,
is_superuser BOOLEAN NOT NULL DEFAULT false,
superuser BOOLEAN NOT NULL DEFAULT false,
role TEXT NOT NULL DEFAULT 'user',
activated_on DATETIME,
oidc_issuer TEXT,
oidc_subject TEXT,
group_users UUID NOT NULL,
PRIMARY KEY (id),
CONSTRAINT users_groups_users FOREIGN KEY (group_users) REFERENCES groups(id) ON DELETE CASCADE,
UNIQUE (oidc_issuer, oidc_subject)
);
-- Copy data back, using the first group from user_groups
INSERT INTO users_old (
id, created_at, updated_at, name, email, password, is_superuser, superuser, role,
activated_on, oidc_issuer, oidc_subject, group_users
)
SELECT
u.id, u.created_at, u.updated_at, u.name, u.email, u.password, u.is_superuser, u.superuser, u.role,
u.activated_on, u.oidc_issuer, u.oidc_subject, COALESCE(u.default_group_id, (SELECT group_id FROM user_groups WHERE user_id = u.id LIMIT 1))
FROM users u;
DROP INDEX IF EXISTS users_email_key;
DROP INDEX IF EXISTS users_oidc_issuer_subject_key;
DROP TABLE users;
ALTER TABLE users_old RENAME TO users;
CREATE UNIQUE INDEX IF NOT EXISTS users_email_key ON users(email);
CREATE UNIQUE INDEX IF NOT EXISTS users_oidc_issuer_subject_key ON users(oidc_issuer, oidc_subject);
DROP TABLE IF EXISTS user_groups;

View File

@@ -313,7 +313,7 @@ func TestItemRepository_GetAllCustomFields(t *testing.T) {
// Test getting all values from field // Test getting all values from field
{ {
results, err := tRepos.Items.GetAllCustomFieldValues(context.Background(), tUser.GroupID, names[0]) results, err := tRepos.Items.GetAllCustomFieldValues(context.Background(), tUser.DefaultGroupID, names[0])
require.NoError(t, err) require.NoError(t, err)
assert.ElementsMatch(t, values[:1], results) assert.ElementsMatch(t, values[:1], results)
@@ -397,5 +397,3 @@ func TestItemsRepository_DeleteByGroupWithAttachments(t *testing.T) {
_, err = tRepos.Attachments.Get(context.Background(), tGroup.ID, attachment.ID) _, err = tRepos.Attachments.Get(context.Background(), tGroup.ID, attachment.ID)
require.Error(t, err) require.Error(t, err)
} }

View File

@@ -40,7 +40,7 @@ func (r *TokenRepository) GetUserFromToken(ctx context.Context, token []byte) (U
Where(authtokens.ExpiresAtGTE(time.Now())). Where(authtokens.ExpiresAtGTE(time.Now())).
WithUser(). WithUser().
QueryUser(). QueryUser().
WithGroup(). WithGroups().
Only(ctx) Only(ctx)
if err != nil { if err != nil {
return UserOut{}, err return UserOut{}, err

View File

@@ -17,12 +17,12 @@ type (
// in the database. It should to create users from an API unless the user has // in the database. It should to create users from an API unless the user has
// rights to create SuperUsers. For regular user in data use the UserIn struct. // rights to create SuperUsers. For regular user in data use the UserIn struct.
UserCreate struct { UserCreate struct {
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
Password *string `json:"password"` Password *string `json:"password"`
IsSuperuser bool `json:"isSuperuser"` IsSuperuser bool `json:"isSuperuser"`
GroupID uuid.UUID `json:"groupID"` DefaultGroupID uuid.UUID `json:"defaultGroupID"`
IsOwner bool `json:"isOwner"` IsOwner bool `json:"isOwner"`
} }
UserUpdate struct { UserUpdate struct {
@@ -31,16 +31,16 @@ type (
} }
UserOut struct { UserOut struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
IsSuperuser bool `json:"isSuperuser"` IsSuperuser bool `json:"isSuperuser"`
GroupID uuid.UUID `json:"groupId"` DefaultGroupID uuid.UUID `json:"defaultGroupId"`
GroupName string `json:"groupName"` GroupIDs []uuid.UUID `json:"groupIds"`
PasswordHash string `json:"-"` PasswordHash string `json:"-"`
IsOwner bool `json:"isOwner"` IsOwner bool `json:"isOwner"`
OidcIssuer *string `json:"oidcIssuer"` OidcIssuer *string `json:"oidcIssuer"`
OidcSubject *string `json:"oidcSubject"` OidcSubject *string `json:"oidcSubject"`
} }
) )
@@ -55,37 +55,48 @@ func mapUserOut(user *ent.User) UserOut {
passwordHash = *user.Password passwordHash = *user.Password
} }
groupIDs := make([]uuid.UUID, len(user.Edges.Groups))
for i, g := range user.Edges.Groups {
groupIDs[i] = g.ID
}
// Get the default group ID, handling the optional pointer
defaultGroupID := uuid.Nil
if user.DefaultGroupID != nil {
defaultGroupID = *user.DefaultGroupID
}
return UserOut{ return UserOut{
ID: user.ID, ID: user.ID,
Name: user.Name, Name: user.Name,
Email: user.Email, Email: user.Email,
IsSuperuser: user.IsSuperuser, IsSuperuser: user.IsSuperuser,
GroupID: user.Edges.Group.ID, DefaultGroupID: defaultGroupID,
GroupName: user.Edges.Group.Name, GroupIDs: groupIDs,
PasswordHash: passwordHash, PasswordHash: passwordHash,
IsOwner: user.Role == "owner", IsOwner: user.Role == "owner",
OidcIssuer: user.OidcIssuer, OidcIssuer: user.OidcIssuer,
OidcSubject: user.OidcSubject, OidcSubject: user.OidcSubject,
} }
} }
func (r *UserRepository) GetOneID(ctx context.Context, id uuid.UUID) (UserOut, error) { func (r *UserRepository) GetOneID(ctx context.Context, id uuid.UUID) (UserOut, error) {
return mapUserOutErr(r.db.User.Query(). return mapUserOutErr(r.db.User.Query().
Where(user.ID(id)). Where(user.ID(id)).
WithGroup(). WithGroups().
Only(ctx)) Only(ctx))
} }
func (r *UserRepository) GetOneEmail(ctx context.Context, email string) (UserOut, error) { func (r *UserRepository) GetOneEmail(ctx context.Context, email string) (UserOut, error) {
return mapUserOutErr(r.db.User.Query(). return mapUserOutErr(r.db.User.Query().
Where(user.EmailEqualFold(email)). Where(user.EmailEqualFold(email)).
WithGroup(). WithGroups().
Only(ctx), Only(ctx),
) )
} }
func (r *UserRepository) GetAll(ctx context.Context) ([]UserOut, error) { func (r *UserRepository) GetAll(ctx context.Context) ([]UserOut, error) {
return mapUsersOutErr(r.db.User.Query().WithGroup().All(ctx)) return mapUsersOutErr(r.db.User.Query().WithGroups().All(ctx))
} }
func (r *UserRepository) Create(ctx context.Context, usr UserCreate) (UserOut, error) { func (r *UserRepository) Create(ctx context.Context, usr UserCreate) (UserOut, error) {
@@ -99,8 +110,9 @@ func (r *UserRepository) Create(ctx context.Context, usr UserCreate) (UserOut, e
SetName(usr.Name). SetName(usr.Name).
SetEmail(usr.Email). SetEmail(usr.Email).
SetIsSuperuser(usr.IsSuperuser). SetIsSuperuser(usr.IsSuperuser).
SetGroupID(usr.GroupID). SetDefaultGroupID(usr.DefaultGroupID).
SetRole(role) SetRole(role).
AddGroupIDs(usr.DefaultGroupID)
// Only set password if provided (non-nil) // Only set password if provided (non-nil)
if usr.Password != nil { if usr.Password != nil {
@@ -126,10 +138,11 @@ func (r *UserRepository) CreateWithOIDC(ctx context.Context, usr UserCreate, iss
SetName(usr.Name). SetName(usr.Name).
SetEmail(usr.Email). SetEmail(usr.Email).
SetIsSuperuser(usr.IsSuperuser). SetIsSuperuser(usr.IsSuperuser).
SetGroupID(usr.GroupID). SetDefaultGroupID(usr.DefaultGroupID).
SetRole(role). SetRole(role).
SetOidcIssuer(issuer). SetOidcIssuer(issuer).
SetOidcSubject(subject) SetOidcSubject(subject).
AddGroupIDs(usr.DefaultGroupID)
if usr.Password != nil { if usr.Password != nil {
createQuery = createQuery.SetPassword(*usr.Password) createQuery = createQuery.SetPassword(*usr.Password)
@@ -183,6 +196,6 @@ func (r *UserRepository) SetOIDCIdentity(ctx context.Context, uid uuid.UUID, iss
func (r *UserRepository) GetOneOIDC(ctx context.Context, issuer, subject string) (UserOut, error) { func (r *UserRepository) GetOneOIDC(ctx context.Context, issuer, subject string) (UserOut, error) {
return mapUserOutErr(r.db.User.Query(). return mapUserOutErr(r.db.User.Query().
Where(user.OidcIssuerEQ(issuer), user.OidcSubjectEQ(subject)). Where(user.OidcIssuerEQ(issuer), user.OidcSubjectEQ(subject)).
WithGroup(). WithGroups().
Only(ctx)) Only(ctx))
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"testing" "testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -11,11 +12,11 @@ import (
func userFactory() UserCreate { func userFactory() UserCreate {
password := fk.Str(10) password := fk.Str(10)
return UserCreate{ return UserCreate{
Name: fk.Str(10), Name: fk.Str(10),
Email: fk.Email(), Email: fk.Email(),
Password: &password, Password: &password,
IsSuperuser: fk.Bool(), IsSuperuser: fk.Bool(),
GroupID: tGroup.ID, DefaultGroupID: tGroup.ID,
} }
} }
@@ -87,7 +88,8 @@ func TestUserRepo_GetAll(t *testing.T) {
assert.Equal(t, usr.Email, usr2.Email) assert.Equal(t, usr.Email, usr2.Email)
// Check groups are loaded // Check groups are loaded
assert.NotNil(t, usr2.GroupID) assert.NotEqual(t, uuid.Nil, usr2.DefaultGroupID)
assert.Greater(t, len(usr2.GroupIDs), 0)
} }
} }
} }