package kong import ( "errors" "fmt" "os" "reflect" "sort" "strconv" "strings" ) // Path records the nodes and parsed values from the current command-line. type Path struct { Parent *Node // One of these will be non-nil. App *Application Positional *Positional Flag *Flag Argument *Argument Command *Command // Flags added by this node. Flags []*Flag // True if this Path element was created as the result of a resolver. Resolved bool // Remaining tokens after this node remainder []Token } // Node returns the Node associated with this Path, or nil if Path is a non-Node. func (p *Path) Node() *Node { switch { case p.App != nil: return p.App.Node case p.Argument != nil: return p.Argument case p.Command != nil: return p.Command } return nil } // Visitable returns the Visitable for this path element. func (p *Path) Visitable() Visitable { switch { case p.App != nil: return p.App case p.Argument != nil: return p.Argument case p.Command != nil: return p.Command case p.Flag != nil: return p.Flag case p.Positional != nil: return p.Positional } return nil } // Remainder returns the remaining unparsed args after this Path element. func (p *Path) Remainder() []string { args := []string{} for _, token := range p.remainder { args = append(args, token.String()) } return args } // Context contains the current parse context. type Context struct { *Kong // A trace through parsed nodes. Path []*Path // Original command-line arguments. Args []string // Error that occurred during trace, if any. Error error values map[*Value]reflect.Value // Temporary values during tracing. bindings bindings resolvers []Resolver // Extra context-specific resolvers. scan *Scanner } // Trace path of "args" through the grammar tree. // // The returned Context will include a Path of all commands, arguments, positionals and flags. // // This just constructs a new trace. To fully apply the trace you must call Reset(), Resolve(), // Validate() and Apply(). 
func Trace(k *Kong, args []string) (*Context, error) { s := Scan(args...).AllowHyphenPrefixedParameters(k.allowHyphenated) c := &Context{ Kong: k, Args: args, Path: []*Path{ {App: k.Model, Flags: k.Model.Flags, remainder: s.PeekAll()}, }, values: map[*Value]reflect.Value{}, scan: s, bindings: bindings{}, } c.Error = c.trace(c.Model.Node) return c, nil } // Bind adds bindings to the Context. func (c *Context) Bind(args ...any) { c.bindings.add(args...) } // BindTo adds a binding to the Context. // // This will typically have to be called like so: // // BindTo(impl, (*MyInterface)(nil)) func (c *Context) BindTo(impl, iface any) { c.bindings.addTo(impl, iface) } // BindToProvider allows binding of provider functions. // // This is useful when the Run() function of different commands require different values that may // not all be initialisable from the main() function. // // "provider" must be a function with the signature func(...) (T, error) or func(...) T, // where ... will be recursively injected with bound values. func (c *Context) BindToProvider(provider any) error { return c.bindings.addProvider(provider, false /* singleton */) } // BindSingletonProvider allows binding of provider functions. // The provider will be called once and the result cached. // // "provider" must be a function with the signature func(...) (T, error) or func(...) T, // where ... will be recursively injected with bound values. func (c *Context) BindSingletonProvider(provider any) error { return c.bindings.addProvider(provider, true /* singleton */) } // Value returns the value for a particular path element. func (c *Context) Value(path *Path) reflect.Value { switch { case path.Positional != nil: return c.values[path.Positional] case path.Flag != nil: return c.values[path.Flag.Value] case path.Argument != nil: return c.values[path.Argument.Argument] } panic("can only retrieve value for flag, argument or positional") } // Selected command or argument. 
func (c *Context) Selected() *Node { var selected *Node for _, path := range c.Path { switch { case path.Command != nil: selected = path.Command case path.Argument != nil: selected = path.Argument } } return selected } // Empty returns true if there were no arguments provided. func (c *Context) Empty() bool { for _, path := range c.Path { if !path.Resolved && path.App == nil { return false } } return true } // Validate the current context. func (c *Context) Validate() error { //nolint: gocyclo err := Visit(c.Model, func(node Visitable, next Next) error { switch node := node.(type) { case *Value: ok := atLeastOneEnvSet(node.Tag.Envs) if node.Enum != "" && (!node.Required || node.HasDefault || (len(node.Tag.Envs) != 0 && ok)) { if err := checkEnum(node, node.Target); err != nil { return err } } case *Flag: ok := atLeastOneEnvSet(node.Tag.Envs) if node.Enum != "" && (!node.Required || node.HasDefault || (len(node.Tag.Envs) != 0 && ok)) { if err := checkEnum(node.Value, node.Target); err != nil { return err } } } return next(nil) }) if err != nil { return err } for _, el := range c.Path { var ( value reflect.Value desc string ) switch node := el.Visitable().(type) { case *Value: value = node.Target desc = node.ShortSummary() case *Flag: value = node.Target desc = node.ShortSummary() case *Application: value = node.Target desc = "" case *Node: value = node.Target desc = node.Path() } if validate := isValidatable(value); validate != nil { if err := validate.Validate(c); err != nil { if desc != "" { return fmt.Errorf("%s: %w", desc, err) } return err } } } for _, resolver := range c.combineResolvers() { if err := resolver.Validate(c.Model); err != nil { return err } } for _, path := range c.Path { var value *Value switch { case path.Flag != nil: value = path.Flag.Value case path.Positional != nil: value = path.Positional } if value != nil && value.Tag.Enum != "" { if err := checkEnum(value, value.Target); err != nil { return err } } if err := 
checkMissingFlags(path.Flags); err != nil { return err } } // Check the terminal node. node := c.Selected() if node == nil { node = c.Model.Node } // Find deepest positional argument so we can check if all required positionals have been provided. positionals := 0 for _, path := range c.Path { if path.Positional != nil { positionals = path.Positional.Position + 1 } } if err := checkMissingChildren(node); err != nil { return err } if err := checkMissingPositionals(positionals, node.Positional); err != nil { return err } if err := checkXorDuplicatedAndAndMissing(c.Path); err != nil { return err } if node.Type == ArgumentNode { value := node.Argument if value.Required && !value.Set { return fmt.Errorf("%s is required", node.Summary()) } } return nil } // Flags returns the accumulated available flags. func (c *Context) Flags() (flags []*Flag) { for _, trace := range c.Path { flags = append(flags, trace.Flags...) } return } // Command returns the full command path. func (c *Context) Command() string { command := []string{} for _, trace := range c.Path { switch { case trace.Positional != nil: command = append(command, "<"+trace.Positional.Name+">") case trace.Argument != nil: command = append(command, "<"+trace.Argument.Name+">") case trace.Command != nil: command = append(command, trace.Command.Name) } } return strings.Join(command, " ") } // AddResolver adds a context-specific resolver. // // This is most useful in the BeforeResolve() hook. func (c *Context) AddResolver(resolver Resolver) { c.resolvers = append(c.resolvers, resolver) } // FlagValue returns the set value of a flag if it was encountered and exists, or its default value. 
func (c *Context) FlagValue(flag *Flag) any {
	// Look for the flag on the traced path first.
	for _, trace := range c.Path {
		if trace.Flag == flag {
			v, ok := c.values[trace.Flag.Value]
			if !ok {
				// The flag is on the path but has no traced value; stop
				// searching and fall back to the target/default below.
				break
			}
			return v.Interface()
		}
	}
	// Not traced: use the flag's target if it is a valid reflect.Value,
	// otherwise its declared default.
	if flag.Target.IsValid() {
		return flag.Target.Interface()
	}
	return flag.DefaultValue.Interface()
}

// Reset recursively resets values to defaults (as specified in the grammar) or the zero value.
//
// Every *Value visited under the model's root node has Reset() called on it;
// the result is handed to next, so whether a failure stops the walk is up to
// Visit/Next.
func (c *Context) Reset() error {
	return Visit(c.Model.Node, func(node Visitable, next Next) error {
		if value, ok := node.(*Value); ok {
			return next(value.Reset())
		}
		return next(nil)
	})
}

// endParsing drains every remaining token from the scanner up to the EOL
// token and pushes their string forms back as PositionalArgumentToken, so
// nothing after this point is interpreted as a flag.
func (c *Context) endParsing() {
	args := []string{}
	for {
		token := c.scan.Pop()
		if token.Type == EOLToken {
			break
		}
		args = append(args, token.String())
	}
	// Note: tokens must be pushed in reverse order.
	// (Presumably PushTyped puts the token at the front of the scanner, so
	// pushing last-to-first keeps the original order — confirm in Scanner.)
	for i := range args {
		c.scan.PushTyped(args[len(args)-1-i], PositionalArgumentToken)
	}
}

// trace consumes tokens from the scanner against "node", appending a Path
// element for every flag, positional, command or argument it matches. On
// matching a child command or argument it recurses into that branch and
// returns the result. When the tokens run out it defers to
// maybeSelectDefault.
//
//nolint:maintidx
func (c *Context) trace(node *Node) (err error) { //nolint: gocyclo
	// Index of the next positional of "node" that has not been consumed yet.
	positional := 0
	node.Active = true

	flags := []*Flag{}
	flagNode := node
	if node.DefaultCmd != nil && node.DefaultCmd.Tag.Default == "withargs" {
		// Add flags of the default command if the current node has one
		// and that default command allows args / flags without explicitly
		// naming the command on the CLI.
		flagNode = node.DefaultCmd
	}
	for _, group := range flagNode.AllFlags(false) {
		flags = append(flags, group...)
	}

	// A passthrough node treats everything that follows as positional.
	if node.Passthrough {
		c.endParsing()
	}

	for !c.scan.Peek().IsEOL() {
		token := c.scan.Peek()
		switch token.Type {
		case UntypedToken:
			// Classify the raw token and push it back with a concrete type;
			// the next loop iteration then handles the typed token.
			switch v := token.Value.(type) {
			case string:

				switch {
				// A lone "-" is treated like any other non-flag string.
				case v == "-":
					fallthrough

				default: //nolint
					c.scan.Pop()
					c.scan.PushTyped(token.Value, PositionalArgumentToken)

				// Indicates end of parsing. All remaining arguments are treated as positional arguments only.
				case v == "--":
					c.endParsing()

					// Pop the -- token unless the next positional argument accepts passthrough arguments.
					if !(positional < len(node.Positional) && node.Positional[positional].Passthrough) {
						c.scan.Pop()
					}

				// Long flag.
				case strings.HasPrefix(v, "--"):
					c.scan.Pop()
					// Parse it and push the tokens.
					// "--name=value" is split once on "="; the value (if any)
					// is pushed before the flag name.
					parts := strings.SplitN(v[2:], "=", 2)
					if len(parts) > 1 {
						c.scan.PushTyped(parts[1], FlagValueToken)
					}
					c.scan.PushTyped(parts[0], FlagToken)

				// Short flag.
				case strings.HasPrefix(v, "-"):
					c.scan.Pop()
					// Note: tokens must be pushed in reverse order.
					// Everything after the first flag character is kept as a
					// tail to be split up on later iterations.
					if tail := v[2:]; tail != "" {
						c.scan.PushTyped(tail, ShortFlagTailToken)
					}
					c.scan.PushTyped(v[1:2], ShortFlagToken)
				}
			default:
				// Non-string untyped values are always positional.
				c.scan.Pop()
				c.scan.PushTyped(token.Value, PositionalArgumentToken)
			}

		case ShortFlagTailToken:
			// Peel one more short flag off the front of the tail.
			c.scan.Pop()
			// Note: tokens must be pushed in reverse order.
			if tail := token.String()[1:]; tail != "" {
				c.scan.PushTyped(tail, ShortFlagTailToken)
			}
			c.scan.PushTyped(token.String()[0:1], ShortFlagToken)

		case FlagToken:
			if err := c.parseFlag(flags, token.String()); err != nil {
				// An unknown flag is re-typed as a positional when the next
				// positional is in PassThroughModeAll; any other error aborts.
				if isUnknownFlagError(err) && positional < len(node.Positional) && node.Positional[positional].PassthroughMode == PassThroughModeAll {
					c.scan.Pop()
					c.scan.PushTyped(token.String(), PositionalArgumentToken)
				} else {
					return err
				}
			}

		case ShortFlagToken:
			if err := c.parseFlag(flags, token.String()); err != nil {
				// Same unknown-flag fallback as for long flags above.
				if isUnknownFlagError(err) && positional < len(node.Positional) && node.Positional[positional].PassthroughMode == PassThroughModeAll {
					c.scan.Pop()
					c.scan.PushTyped(token.String(), PositionalArgumentToken)
				} else {
					return err
				}
			}

		case FlagValueToken:
			// A flag value token seen here was not consumed by a flag parse.
			return fmt.Errorf("unexpected flag argument %q", token.Value)

		case PositionalArgumentToken:
			// Names of visible child commands, used for "did you mean" hints.
			candidates := []string{}

			// Ensure we've consumed all positional arguments.
			if positional < len(node.Positional) {
				arg := node.Positional[positional]

				if arg.Passthrough {
					c.endParsing()
				}

				arg.Active = true
				err := arg.Parse(c.scan, c.getValue(arg))
				if err != nil {
					return err
				}
				c.Path = append(c.Path, &Path{
					Parent:     node,
					Positional: arg,
					remainder:  c.scan.PeekAll(),
				})
				positional++
				// Leaves the switch; the outer loop continues with the next token.
				break
			}

			// Assign token value to a branch name if tagged as an alias
			// An alias will be ignored in the case of an existing command
			cmds := make(map[string]bool)
			for _, branch := range node.Children {
				if branch.Type == CommandNode {
					cmds[branch.Name] = true
				}
			}
			for _, branch := range node.Children {
				for _, a := range branch.Aliases {
					_, ok := cmds[a]
					if token.Value == a && !ok {
						// Only the local copy of the token is rewritten; it
						// is compared against branch names below.
						token.Value = branch.Name
						break
					}
				}
			}

			// After positional arguments have been consumed, check commands next...
			for _, branch := range node.Children {
				if branch.Type == CommandNode && !branch.Hidden {
					candidates = append(candidates, branch.Name)
				}
				if branch.Type == CommandNode && branch.Name == token.Value {
					c.scan.Pop()
					c.Path = append(c.Path, &Path{
						Parent:    node,
						Command:   branch,
						Flags:     branch.Flags,
						remainder: c.scan.PeekAll(),
					})
					return c.trace(branch)
				}
			}

			// Finally, check arguments.
			// The first argument branch whose Parse succeeds is taken;
			// parse errors from branches that do not match are discarded.
			for _, branch := range node.Children {
				if branch.Type == ArgumentNode {
					arg := branch.Argument
					if err := arg.Parse(c.scan, c.getValue(arg)); err == nil {
						c.Path = append(c.Path, &Path{
							Parent:    node,
							Argument:  branch,
							Flags:     branch.Flags,
							remainder: c.scan.PeekAll(),
						})
						return c.trace(branch)
					}
				}
			}

			// If there is a default command that allows args and nothing else
			// matches, take the branch of the default command
			if node.DefaultCmd != nil && node.DefaultCmd.Tag.Default == "withargs" {
				c.Path = append(c.Path, &Path{
					Parent:    node,
					Command:   node.DefaultCmd,
					Flags:     node.DefaultCmd.Flags,
					remainder: c.scan.PeekAll(),
				})
				return c.trace(node.DefaultCmd)
			}

			return findPotentialCandidates(token.String(), candidates, "unexpected argument %s", token)
		default:
			return fmt.Errorf("unexpected token %s", token)
		}
	}
	return c.maybeSelectDefault(flags, node)
}

// IgnoreDefault can be implemented by flags that want to be applied before any default commands.
type IgnoreDefault interface {
	IgnoreDefault()
}

// End of the line, check for a default command, but only if we're not displaying help,
// otherwise we'd only ever display the help for the default command.
//
// If any flag in "flags" is set and its target implements IgnoreDefault,
// nothing is appended. Otherwise, when node has a default command, a Path
// element for that command is appended. The returned error is always nil.
func (c *Context) maybeSelectDefault(flags []*Flag, node *Node) error {
	for _, flag := range flags {
		if _, ok := flag.Target.Interface().(IgnoreDefault); ok && flag.Set {
			return nil
		}
	}
	if node.DefaultCmd != nil {
		// NOTE(review): Parent is the default command itself here, whereas
		// trace() uses the enclosing node as Parent for the same kind of
		// element — confirm this is intended.
		c.Path = append(c.Path, &Path{
			Parent:    node.DefaultCmd,
			Command:   node.DefaultCmd,
			Flags:     node.DefaultCmd.Flags,
			remainder: c.scan.PeekAll(),
		})
	}
	return nil
}

// Resolve walks through the traced path, applying resolvers to any unset flags.
//
// For each flag without a traced value, every resolver is consulted and the
// last non-nil result is parsed into the flag. A Path element marked
// Resolved is appended for each flag resolved this way. If there are no
// resolvers at all, Resolve does nothing.
func (c *Context) Resolve() error {
	resolvers := c.combineResolvers()
	if len(resolvers) == 0 {
		return nil
	}

	// Collected separately so c.Path is not modified while it is ranged over.
	inserted := []*Path{}
	for _, path := range c.Path {
		for _, flag := range path.Flags {
			// Flag has already been set on the command-line.
			if _, ok := c.values[flag.Value]; ok {
				continue
			}

			// Pick the last resolved value.
			var selected any
			for _, resolver := range resolvers {
				s, err := resolver.Resolve(c, path, flag)
				if err != nil {
					return fmt.Errorf("%s: %w", flag.ShortSummary(), err)
				}
				if s == nil {
					continue
				}
				selected = s
			}

			if selected == nil {
				continue
			}

			// Feed the resolved value through the flag's normal parser via
			// a scanner holding just that value.
			scan := Scan().PushTyped(selected, FlagValueToken)
			delete(c.values, flag.Value)
			err := flag.Parse(scan, c.getValue(flag.Value))
			if err != nil {
				return err
			}
			inserted = append(inserted, &Path{
				Flag:      flag,
				Resolved:  true,
				remainder: c.scan.PeekAll(),
			})
		}
	}
	c.Path = append(c.Path, inserted...)
	return nil
}

// Combine application-level resolvers and context resolvers.
// Application-level resolvers come first in the returned slice.
func (c *Context) combineResolvers() []Resolver {
	resolvers := []Resolver{}
	resolvers = append(resolvers, c.Kong.resolvers...)
	resolvers = append(resolvers, c.resolvers...)
	return resolvers
}

// getValue returns the temporary traced value for "value", creating and
// caching it on first use. A new value has the type of value.Target;
// pointers, slices and maps are initialised to a fresh non-nil instance.
func (c *Context) getValue(value *Value) reflect.Value {
	v, ok := c.values[value]
	if !ok {
		v = reflect.New(value.Target.Type()).Elem()
		switch v.Kind() {
		case reflect.Ptr:
			v.Set(reflect.New(v.Type().Elem()))

		case reflect.Slice:
			v.Set(reflect.MakeSlice(v.Type(), 0, 0))

		case reflect.Map:
			v.Set(reflect.MakeMap(v.Type()))

		default:
		}
		c.values[value] = v
	}
	return v
}

// ApplyDefaults if they are not already set.
//
// Walks the model from its root node and calls ApplyDefault() on the Value
// of every flag, on the Argument of every node (when non-nil) and on every
// bare Value, stopping at the first error.
func (c *Context) ApplyDefaults() error {
	return Visit(c.Model.Node, func(node Visitable, next Next) error {
		var value *Value
		switch node := node.(type) {
		case *Flag:
			value = node.Value
		case *Node:
			value = node.Argument
		case *Value:
			value = node
		default:
		}
		if value != nil {
			if err := value.ApplyDefault(); err != nil {
				return err
			}
		}
		return next(nil)
	})
}

// Apply traced context to the target grammar.
func (c *Context) Apply() (string, error) { path := []string{} for _, trace := range c.Path { var value *Value switch { case trace.App != nil: case trace.Argument != nil: path = append(path, "<"+trace.Argument.Name+">") value = trace.Argument.Argument case trace.Command != nil: path = append(path, trace.Command.Name) case trace.Flag != nil: value = trace.Flag.Value case trace.Positional != nil: path = append(path, "<"+trace.Positional.Name+">") value = trace.Positional default: panic("unsupported path ?!") } if value != nil { value.Apply(c.getValue(value)) } } return strings.Join(path, " "), nil } func flipBoolValue(value reflect.Value) error { if value.Kind() == reflect.Bool { value.SetBool(!value.Bool()) return nil } if value.Kind() == reflect.Ptr { if !value.IsNil() { return flipBoolValue(value.Elem()) } return nil } return fmt.Errorf("cannot negate a value of %s", value.Type().String()) } func (c *Context) parseFlag(flags []*Flag, match string) (err error) { candidates := []string{} for _, flag := range flags { long := "--" + flag.Name matched := long == match candidates = append(candidates, long) if flag.Short != 0 { short := "-" + string(flag.Short) matched = matched || (short == match) candidates = append(candidates, short) } for _, alias := range flag.Aliases { alias = "--" + alias matched = matched || (alias == match) candidates = append(candidates, alias) } neg := negatableFlagName(flag.Name, flag.Tag.Negatable) if !matched && match != neg { continue } // Found a matching flag. 
c.scan.Pop() if match == neg && flag.Tag.Negatable != "" { flag.Negated = true } err := flag.Parse(c.scan, c.getValue(flag.Value)) if err != nil { var expected *expectedError if errors.As(err, &expected) && expected.token.InferredType().IsAny(FlagToken, ShortFlagToken) { return fmt.Errorf("%s; perhaps try %s=%q?", err.Error(), flag.ShortSummary(), expected.token) } return err } if flag.Negated { value := c.getValue(flag.Value) err := flipBoolValue(value) if err != nil { return err } flag.Value.Apply(value) } c.Path = append(c.Path, &Path{ Flag: flag, remainder: c.scan.PeekAll(), }) return nil } return &unknownFlagError{Cause: findPotentialCandidates(match, candidates, "unknown flag %s", match)} } func isUnknownFlagError(err error) bool { var unknown *unknownFlagError return errors.As(err, &unknown) } type unknownFlagError struct{ Cause error } func (e *unknownFlagError) Unwrap() error { return e.Cause } func (e *unknownFlagError) Error() string { return e.Cause.Error() } // Call an arbitrary function filling arguments with bound values. func (c *Context) Call(fn any, binds ...any) (out []any, err error) { fv := reflect.ValueOf(fn) bindings := c.Kong.bindings.clone().add(binds...).add(c).merge(c.bindings) return callAnyFunction(fv, bindings) } // RunNode calls the Run() method on an arbitrary node. // // This is useful in conjunction with Visit(), for dynamically running commands. // // Any passed values will be bindable to arguments of the target Run() method. Additionally, // all parent nodes in the command structure will be bound. 
func (c *Context) RunNode(node *Node, binds ...any) (err error) { type targetMethod struct { node *Node method reflect.Value binds bindings } methodBinds := c.Kong.bindings.clone().add(binds...).add(c).merge(c.bindings) methods := []targetMethod{} for i := 0; node != nil; i, node = i+1, node.Parent { method := getMethod(node.Target, "Run") methodBinds = methodBinds.clone() for p := node; p != nil; p = p.Parent { methodBinds = methodBinds.add(p.Target.Addr().Interface()) // Try value and pointer to value. for _, p := range []reflect.Value{p.Target, p.Target.Addr()} { t := p.Type() for i := 0; i < p.NumMethod(); i++ { methodt := t.Method(i) if strings.HasPrefix(methodt.Name, "Provide") { method := p.Method(i) if err := methodBinds.addProvider(method.Interface(), false /* singleton */); err != nil { return fmt.Errorf("%s.%s: %w", t.Name(), methodt.Name, err) } } } } } if method.IsValid() { methods = append(methods, targetMethod{node, method, methodBinds}) } } if len(methods) == 0 { return fmt.Errorf("no Run() method found in hierarchy of %s", c.Selected().Summary()) } for _, method := range methods { if err = callFunction(method.method, method.binds); err != nil { return err } } return nil } // Run executes the Run() method on the selected command, which must exist. // // Any passed values will be bindable to arguments of the target Run() method. Additionally, // all parent nodes in the command structure will be bound. func (c *Context) Run(binds ...any) (err error) { node := c.Selected() if node == nil { if len(c.Path) == 0 { return fmt.Errorf("no command selected") } selected := c.Path[0].Node() if selected.Type == ApplicationNode { method := getMethod(selected.Target, "Run") if method.IsValid() { node = selected } } if node == nil { return fmt.Errorf("no command selected") } } runErr := c.RunNode(node, binds...) err = c.Kong.applyHook(c, "AfterRun") return errors.Join(runErr, err) } // PrintUsage to Kong's stdout. 
// // If summary is true, a summarised version of the help will be output. func (c *Context) PrintUsage(summary bool) error { options := c.helpOptions options.Summary = summary return c.help(options, c) } func checkMissingFlags(flags []*Flag) error { xorGroupSet := map[string]bool{} xorGroup := map[string][]string{} andGroupSet := map[string]bool{} andGroup := map[string][]string{} missing := []string{} andGroupRequired := getRequiredAndGroupMap(flags) for _, flag := range flags { for _, and := range flag.And { flag.Required = andGroupRequired[and] } if flag.Set { for _, xor := range flag.Xor { xorGroupSet[xor] = true } for _, and := range flag.And { andGroupSet[and] = true } } if !flag.Required || flag.Set { continue } if len(flag.Xor) > 0 || len(flag.And) > 0 { for _, xor := range flag.Xor { if xorGroupSet[xor] { continue } xorGroup[xor] = append(xorGroup[xor], flag.Summary()) } for _, and := range flag.And { andGroup[and] = append(andGroup[and], flag.Summary()) } } else { missing = append(missing, flag.Summary()) } } for xor, flags := range xorGroup { if !xorGroupSet[xor] && len(flags) > 1 { missing = append(missing, strings.Join(flags, " or ")) } } for _, flags := range andGroup { if len(flags) > 1 { missing = append(missing, strings.Join(flags, " and ")) } } if len(missing) == 0 { return nil } sort.Strings(missing) return fmt.Errorf("missing flags: %s", strings.Join(missing, ", ")) } func getRequiredAndGroupMap(flags []*Flag) map[string]bool { andGroupRequired := map[string]bool{} for _, flag := range flags { for _, and := range flag.And { if flag.Required { andGroupRequired[and] = true } } } return andGroupRequired } func checkMissingChildren(node *Node) error { missing := []string{} missingArgs := []string{} for _, arg := range node.Positional { if arg.Required && !arg.Set { missingArgs = append(missingArgs, arg.Summary()) } } if len(missingArgs) > 0 { missing = append(missing, strconv.Quote(strings.Join(missingArgs, " "))) } for _, child := range 
node.Children { if child.Hidden { continue } if child.Argument != nil { if !child.Argument.Required { continue } missing = append(missing, strconv.Quote(child.Summary())) } else { missing = append(missing, strconv.Quote(child.Name)) } } if len(missing) == 0 { return nil } if len(missing) > 5 { missing = append(missing[:5], "...") } if len(missing) == 1 { return fmt.Errorf("expected %s", missing[0]) } return fmt.Errorf("expected one of %s", strings.Join(missing, ", ")) } // If we're missing any positionals and they're required, return an error. func checkMissingPositionals(positional int, values []*Value) error { // All the positionals are in. if positional >= len(values) { return nil } // We're low on supplied positionals, but the missing one is optional. if !values[positional].Required { return nil } missing := []string{} for ; positional < len(values); positional++ { arg := values[positional] // TODO(aat): Fix hardcoding of these env checks all over the place :\ if len(arg.Tag.Envs) != 0 { if atLeastOneEnvSet(arg.Tag.Envs) { continue } } missing = append(missing, "<"+arg.Name+">") } if len(missing) == 0 { return nil } return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " ")) } func checkEnum(value *Value, target reflect.Value) error { switch target.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < target.Len(); i++ { if err := checkEnum(value, target.Index(i)); err != nil { return err } } return nil case reflect.Map, reflect.Struct: return errors.New("enum can only be applied to a slice or value") case reflect.Ptr: if target.IsNil() { return nil } return checkEnum(value, target.Elem()) default: enumSlice := value.EnumSlice() v := fmt.Sprintf("%v", target) enums := []string{} for _, enum := range enumSlice { if enum == v { return nil } enums = append(enums, fmt.Sprintf("%q", enum)) } return fmt.Errorf("%s must be one of %s but got %q", value.ShortSummary(), strings.Join(enums, ","), fmt.Sprintf("%v", target.Interface())) } } func 
checkPassthroughArg(target reflect.Value) bool { typ := target.Type() switch typ.Kind() { case reflect.Slice: return typ.Elem().Kind() == reflect.String default: return false } } func checkXorDuplicatedAndAndMissing(paths []*Path) error { errs := []string{} if err := checkXorDuplicates(paths); err != nil { errs = append(errs, err.Error()) } if err := checkAndMissing(paths); err != nil { errs = append(errs, err.Error()) } if len(errs) > 0 { return errors.New(strings.Join(errs, ", ")) } return nil } func checkXorDuplicates(paths []*Path) error { for _, path := range paths { seen := map[string]*Flag{} for _, flag := range path.Flags { if !flag.Set { continue } for _, xor := range flag.Xor { if seen[xor] != nil { return fmt.Errorf("--%s and --%s can't be used together", seen[xor].Name, flag.Name) } seen[xor] = flag } } } return nil } func checkAndMissing(paths []*Path) error { for _, path := range paths { missingMsgs := []string{} andGroups := map[string][]*Flag{} for _, flag := range path.Flags { for _, and := range flag.And { andGroups[and] = append(andGroups[and], flag) } } for _, flags := range andGroups { oneSet := false notSet := []*Flag{} flagNames := []string{} for _, flag := range flags { flagNames = append(flagNames, flag.Name) if flag.Set { oneSet = true } else { notSet = append(notSet, flag) } } if len(notSet) > 0 && oneSet { missingMsgs = append(missingMsgs, fmt.Sprintf("--%s must be used together", strings.Join(flagNames, " and --"))) } } if len(missingMsgs) > 0 { return fmt.Errorf("%s", strings.Join(missingMsgs, ", ")) } } return nil } func findPotentialCandidates(needle string, haystack []string, format string, args ...any) error { if len(haystack) == 0 { return fmt.Errorf(format, args...) 
} closestCandidates := []string{} for _, candidate := range haystack { if strings.HasPrefix(candidate, needle) || levenshtein(candidate, needle) <= 2 { closestCandidates = append(closestCandidates, fmt.Sprintf("%q", candidate)) } } prefix := fmt.Sprintf(format, args...) if len(closestCandidates) == 1 { return fmt.Errorf("%s, did you mean %s?", prefix, closestCandidates[0]) } else if len(closestCandidates) > 1 { return fmt.Errorf("%s, did you mean one of %s?", prefix, strings.Join(closestCandidates, ", ")) } return fmt.Errorf("%s", prefix) } type validatable interface{ Validate() error } type extendedValidatable interface { Validate(kctx *Context) error } // Proxy a validatable function to the extendedValidatable interface type validatableFunc func() error func (f validatableFunc) Validate(kctx *Context) error { return f() } func isValidatable(v reflect.Value) extendedValidatable { if !v.IsValid() || (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice || v.Kind() == reflect.Map) && v.IsNil() { return nil } if validate, ok := v.Interface().(validatable); ok { return validatableFunc(validate.Validate) } if validate, ok := v.Interface().(extendedValidatable); ok { return validate } if v.CanAddr() { return isValidatable(v.Addr()) } return nil } func atLeastOneEnvSet(envs []string) bool { for _, env := range envs { if _, ok := os.LookupEnv(env); ok { return true } } return false }