diff --git a/backend/internal/data/repo/repo_items.go b/backend/internal/data/repo/repo_items.go index ed5efc2f..af9e9a9c 100644 --- a/backend/internal/data/repo/repo_items.go +++ b/backend/internal/data/repo/repo_items.go @@ -295,18 +295,14 @@ func (e *ItemsRepository) publishMutationEvent(gid uuid.UUID) { } func (e *ItemsRepository) getOne(ctx context.Context, where ...predicate.Entity) (ItemOut, error) { - q := e.db.Entity.Query().Where(where...) + q := e.db.Entity.Query().Where(where...).Where(entity.TypeEQ("item")) return mapItemOutErr(q. WithFields(). WithLabel(). - WithChildren(func(query *ent.EntityQuery) { - query.Where(entity.TypeEQ("location")) - }). + WithLocation(). WithGroup(). - WithParent(func(query *ent.EntityQuery) { - query.Where(entity.TypeEQ("item")) - }). + WithParent(). WithAttachments(). Only(ctx), ) @@ -315,7 +311,7 @@ func (e *ItemsRepository) getOne(ctx context.Context, where ...predicate.Entity) // GetOne returns a single item by ID. If the item does not exist, an error is returned. // See also: GetOneByGroup to ensure that the item belongs to a specific group. func (e *ItemsRepository) GetOne(ctx context.Context, id uuid.UUID) (ItemOut, error) { - return e.getOne(ctx, entity.ID(id), entity.TypeEQ("item")) + return e.getOne(ctx, entity.ID(id)) } func (e *ItemsRepository) CheckRef(ctx context.Context, gid uuid.UUID, ref string) (bool, error) { @@ -361,7 +357,6 @@ func (e *ItemsRepository) QueryByGroup(ctx context.Context, gid uuid.UUID, q Ite entity.ManufacturerContainsFold(q.Search), entity.NotesContainsFold(q.Search), ), - entity.TypeEQ("item"), ) } @@ -382,9 +377,9 @@ func (e *ItemsRepository) QueryByGroup(ctx context.Context, gid uuid.UUID, q Ite labelPredicates := make([]predicate.Entity, 0, len(q.LabelIDs)) for _, l := range q.LabelIDs { if !q.NegateLabels { - labelPredicates = append(labelPredicates, entity.HasLabelWith(label.ID(l)), entity.TypeEQ("item")) + labelPredicates = append(labelPredicates, entity.HasLabelWith(label.ID(l))) } else { - labelPredicates = append(labelPredicates, entity.Not(entity.HasLabelWith(label.ID(l))), entity.TypeEQ("item")) + labelPredicates = append(labelPredicates, entity.Not(entity.HasLabelWith(label.ID(l)))) } } if !q.NegateLabels { @@ -418,7 +413,7 @@ func (e *ItemsRepository) QueryByGroup(ctx context.Context, gid uuid.UUID, q Ite if len(q.LocationIDs) > 0 { locationPredicates := make([]predicate.Entity, 0, len(q.LocationIDs)) for _, l := range q.LocationIDs { - locationPredicates = append(locationPredicates, entity.HasParentWith(entity.ID(l), entity.TypeEQ("location"))) + locationPredicates = append(locationPredicates, entity.HasLocationWith(entity.ID(l), entity.TypeEQ("location"))) } andPredicates = append(andPredicates, entity.Or(locationPredicates...)) @@ -439,7 +434,7 @@ func (e *ItemsRepository) QueryByGroup(ctx context.Context, gid uuid.UUID, q Ite } if len(q.ParentItemIDs) > 0 { - andPredicates = append(andPredicates, entity.HasParentWith(entity.IDIn(q.ParentItemIDs...), entity.TypeEQ("item"))) + andPredicates = append(andPredicates, entity.HasParentWith(entity.IDIn(q.ParentItemIDs...))) } } @@ -466,9 +461,7 @@ func (e *ItemsRepository) QueryByGroup(ctx context.Context, gid uuid.UUID, q Ite qb = qb. WithLabel(). - WithParent(func(query *ent.EntityQuery) { - query.Where(entity.TypeEQ("location")) - }). + WithLocation(). WithAttachments(func(aq *ent.AttachmentQuery) { aq.Where( attachment.Primary(true), @@ -513,9 +506,7 @@ func (e *ItemsRepository) QueryByAssetID(ctx context.Context, gid uuid.UUID, ass items, err := mapItemsSummaryErr( qb.Order(ent.Asc(entity.FieldName)). WithLabel(). - WithParent(func(query *ent.EntityQuery) { - query.Where(entity.TypeEQ("location")) - }). + WithLocation(). All(ctx), ) if err != nil { @@ -535,9 +526,7 @@ func (e *ItemsRepository) GetAll(ctx context.Context, gid uuid.UUID) ([]ItemOut, return mapItemsOutErr(e.db.Entity.Query(). Where(entity.HasGroupWith(group.ID(gid)), entity.TypeEQ("item")). WithLabel(). - WithParent(func(query *ent.EntityQuery) { - query.Where(entity.TypeEQ("location")) - }). + WithLocation(). WithFields(). All(ctx)) } @@ -546,7 +535,6 @@ func (e *ItemsRepository) GetAllZeroAssetID(ctx context.Context, gid uuid.UUID) q := e.db.Entity.Query().Where( entity.HasGroupWith(group.ID(gid)), entity.AssetID(0), - entity.TypeEQ("item"), ).Order( ent.Asc(entity.FieldCreatedAt), ) @@ -557,7 +545,6 @@ func (e *ItemsRepository) GetAllZeroAssetID(ctx context.Context, gid uuid.UUID) func (e *ItemsRepository) GetHighestAssetID(ctx context.Context, gid uuid.UUID) (AssetID, error) { q := e.db.Entity.Query().Where( entity.HasGroupWith(group.ID(gid)), - entity.TypeEQ("item"), ).Order( ent.Desc(entity.FieldAssetID), ).Limit(1) @@ -577,7 +564,6 @@ func (e *ItemsRepository) SetAssetID(ctx context.Context, gid uuid.UUID, id uuid q := e.db.Entity.Update().Where( entity.HasGroupWith(group.ID(gid)), entity.ID(id), - entity.TypeEQ("item"), ) _, err := q.SetAssetID(int(assetID)).Save(ctx) @@ -591,6 +577,7 @@ func (e *ItemsRepository) Create(ctx context.Context, gid uuid.UUID, data ItemCr SetQuantity(data.Quantity). SetDescription(data.Description). SetGroupID(gid). + SetLocationID(data.LocationID). SetAssetID(int(data.AssetID)) if len(data.LabelIDs) > 0 { @@ -607,7 +594,7 @@ func (e *ItemsRepository) Create(ctx context.Context, gid uuid.UUID, data ItemCr } func (e *ItemsRepository) Delete(ctx context.Context, id uuid.UUID) error { - err := e.db.Item.DeleteOneID(id).Exec(ctx) + err := e.db.Entity.DeleteOneID(id).Exec(ctx) if err != nil { return err } @@ -617,11 +604,11 @@ func (e *ItemsRepository) Delete(ctx context.Context, id uuid.UUID) error { } func (e *ItemsRepository) DeleteByGroup(ctx context.Context, gid, id uuid.UUID) error { - _, err := e.db.Item. + _, err := e.db.Entity. Delete(). Where( - item.ID(id), - item.HasGroupWith(group.ID(gid)), + entity.ID(id), + entity.HasGroupWith(group.ID(gid)), ).Exec(ctx) if err != nil { return err @@ -632,7 +619,7 @@ func (e *ItemsRepository) DeleteByGroup(ctx context.Context, gid, id uuid.UUID) } func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data ItemUpdate) (ItemOut, error) { - q := e.db.Item.Update().Where(item.ID(data.ID), item.HasGroupWith(group.ID(gid))). + q := e.db.Entity.Update().Where(entity.ID(data.ID), entity.HasGroupWith(group.ID(gid))). SetName(data.Name). SetDescription(data.Description). SetLocationID(data.LocationID). @@ -654,9 +641,9 @@ func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data SetWarrantyDetails(data.WarrantyDetails). SetQuantity(data.Quantity). SetAssetID(int(data.AssetID)). - SetSyncChildItemsLocations(data.SyncChildItemsLocations) + SetSyncChildEntitiesLocations(data.SyncChildItemsLocations) - currentLabels, err := e.db.Item.Query().Where(item.ID(data.ID)).QueryLabel().All(ctx) + currentLabels, err := e.db.Entity.Query().Where(entity.ID(data.ID), entity.TypeEQ("item")).QueryLabel().All(ctx) if err != nil { return ItemOut{}, err } @@ -682,7 +669,7 @@ func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data } if data.SyncChildItemsLocations { - children, err := e.db.Item.Query().Where(item.ID(data.ID)).QueryChildren().All(ctx) + children, err := e.db.Entity.Query().Where(entity.ID(data.ID), entity.TypeEQ("item")).QueryChildren().All(ctx) if err != nil { return ItemOut{}, err } @@ -708,7 +695,7 @@ func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data return ItemOut{}, err } - fields, err := e.db.ItemField.Query().Where(itemfield.HasItemWith(item.ID(data.ID))).All(ctx) + fields, err := e.db.EntityField.Query().Where(entityfield.HasEntityWith(entity.ID(data.ID), entity.TypeEQ("item"))).All(ctx) if err != nil { return ItemOut{}, err } @@ -719,9 +706,9 @@ func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data for _, f := range data.Fields { if f.ID == uuid.Nil { // Create New Field - _, err = e.db.ItemField.Create(). - SetItemID(data.ID). - SetType(itemfield.Type(f.Type)). + _, err = e.db.EntityField.Create(). + SetEntityID(data.ID). + SetType(entityfield.Type(f.Type)). SetName(f.Name). SetTextValue(f.TextValue). SetNumberValue(f.NumberValue). @@ -733,12 +720,12 @@ func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data } } - opt := e.db.ItemField.Update(). + opt := e.db.EntityField.Update(). Where( - itemfield.ID(f.ID), - itemfield.HasItemWith(item.ID(data.ID)), + entityfield.ID(f.ID), + entityfield.HasEntityWith(entity.ID(data.ID), entity.TypeEQ("item")), ). - SetType(itemfield.Type(f.Type)). + SetType(entityfield.Type(f.Type)). SetName(f.Name). SetTextValue(f.TextValue). SetNumberValue(f.NumberValue). @@ -756,10 +743,10 @@ func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data // Delete Fields that are no longer present if fieldIds.Len() > 0 { - _, err = e.db.ItemField.Delete(). + _, err = e.db.EntityField.Delete(). Where( - itemfield.IDIn(fieldIds.Slice()...), - itemfield.HasItemWith(item.ID(data.ID)), + entityfield.IDIn(fieldIds.Slice()...), + entityfield.HasEntityWith(entity.ID(data.ID), entity.TypeEQ("item")), ).Exec(ctx) if err != nil { return ItemOut{}, err @@ -773,15 +760,16 @@ func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data func (e *ItemsRepository) GetAllZeroImportRef(ctx context.Context, gid uuid.UUID) ([]uuid.UUID, error) { var ids []uuid.UUID - err := e.db.Item.Query(). + err := e.db.Entity.Query(). Where( - item.HasGroupWith(group.ID(gid)), - item.Or( - item.ImportRefEQ(""), - item.ImportRefIsNil(), + entity.HasGroupWith(group.ID(gid)), + entity.TypeEQ("item"), + entity.Or( + entity.ImportRefEQ(""), + entity.ImportRefIsNil(), ), ). - Select(item.FieldID). + Select(entity.FieldID). Scan(ctx, &ids) if err != nil { return nil, err @@ -791,10 +779,11 @@ func (e *ItemsRepository) GetAllZeroImportRef(ctx context.Context, gid uuid.UUID } func (e *ItemsRepository) Patch(ctx context.Context, gid, id uuid.UUID, data ItemPatch) error { - q := e.db.Item.Update(). + q := e.db.Entity.Update(). Where( - item.ID(id), - item.HasGroupWith(group.ID(gid)), + entity.ID(id), + entity.HasGroupWith(group.ID(gid)), + entity.TypeEQ("item"), ) if data.ImportRef != nil { @@ -816,16 +805,16 @@ func (e *ItemsRepository) GetAllCustomFieldValues(ctx context.Context, gid uuid. var values []st - err := e.db.Item.Query(). + err := e.db.Entity.Query(). Where( - item.HasGroupWith(group.ID(gid)), + entity.HasGroupWith(group.ID(gid)), ). QueryFields(). Where( - itemfield.Name(name), + entityfield.Name(name), ). Unique(true). - Select(itemfield.FieldTextValue). + Select(entityfield.FieldTextValue). Scan(ctx, &values) if err != nil { return nil, fmt.Errorf("failed to get field values: %w", err) @@ -846,13 +835,14 @@ func (e *ItemsRepository) GetAllCustomFieldNames(ctx context.Context, gid uuid.U var fields []st - err := e.db.Item.Query(). + err := e.db.Entity.Query(). Where( - item.HasGroupWith(group.ID(gid)), + entity.TypeEQ("item"), + entity.HasGroupWith(group.ID(gid)), ). QueryFields(). Unique(true). - Select(itemfield.FieldName). + Select(entityfield.FieldName). Scan(ctx, &fields) if err != nil { return nil, fmt.Errorf("failed to get custom fields: %w", err) @@ -873,15 +863,16 @@ func (e *ItemsRepository) GetAllCustomFieldNames(ctx context.Context, gid uuid.U // frontend. This function is intended to be used as a one-time fix for existing databases and may be // removed in the future. func (e *ItemsRepository) ZeroOutTimeFields(ctx context.Context, gid uuid.UUID) (int, error) { - q := e.db.Item.Query().Where( - item.HasGroupWith(group.ID(gid)), - item.Or( - item.PurchaseTimeNotNil(), - item.PurchaseFromLT("0002-01-01"), - item.SoldTimeNotNil(), - item.SoldToLT("0002-01-01"), - item.WarrantyExpiresNotNil(), - item.WarrantyDetailsLT("0002-01-01"), + q := e.db.Entity.Query().Where( + entity.HasGroupWith(group.ID(gid)), + entity.TypeEQ("item"), + entity.Or( + entity.PurchaseTimeNotNil(), + entity.PurchaseFromLT("0002-01-01"), + entity.SoldTimeNotNil(), + entity.SoldToLT("0002-01-01"), + entity.WarrantyExpiresNotNil(), + entity.WarrantyDetailsLT("0002-01-01"), ), ) @@ -897,7 +888,7 @@ func (e *ItemsRepository) ZeroOutTimeFields(ctx context.Context, gid uuid.UUID) updated := 0 for _, i := range items { - updateQ := e.db.Item.Update().Where(item.ID(i.ID)) + updateQ := e.db.Entity.Update().Where(entity.ID(i.ID), entity.TypeEQ("item")) if !i.PurchaseTime.IsZero() { switch { @@ -945,10 +936,11 @@ func (e *ItemsRepository) ZeroOutTimeFields(ctx context.Context, gid uuid.UUID) func (e *ItemsRepository) SetPrimaryPhotos(ctx context.Context, gid uuid.UUID) (int, error) { // All items where there is no primary photo - itemIDs, err := e.db.Item.Query(). + itemIDs, err := e.db.Entity.Query(). Where( - item.HasGroupWith(group.ID(gid)), - item.HasAttachmentsWith( + entity.HasGroupWith(group.ID(gid)), + entity.TypeEQ("item"), + entity.HasAttachmentsWith( attachment.TypeEQ(attachment.TypePhoto), attachment.Not( attachment.And( @@ -968,7 +960,7 @@ func (e *ItemsRepository) SetPrimaryPhotos(ctx context.Context, gid uuid.UUID) ( // Find the first photo attachment a, err := e.db.Attachment.Query(). Where( - attachment.HasItemWith(item.ID(id)), + attachment.HasEntityWith(entity.ID(id), entity.TypeEQ("item")), attachment.TypeEQ(attachment.TypePhoto), attachment.Primary(false), ).