2022-04-25 12:26:25 -05:00

382 lines
8.4 KiB
Go

package flags
import (
"fmt"
"strings"
)
// INTERNAL_GROUP_NAME is the name of the group that NewSet registers
// on every Set it creates.
const INTERNAL_GROUP_NAME = "__internal__"

// ErrorHandling selects how Parse reports an error it encounters.
type ErrorHandling uint

const (
	// ReturnOnError makes Parse hand the error back to its caller.
	ReturnOnError ErrorHandling = iota
	// PanicOnError makes Parse panic with the error instead of returning it.
	PanicOnError
)
// UnknownHandling selects what Parse does when it encounters a flag
// that is not defined in the Set.
type UnknownHandling uint

// The constants are explicitly typed as UnknownHandling so that they
// cannot be mixed up with ErrorHandling values or plain integers.
const (
	// PassOnUnknown makes Parse keep unknown flags in the remaining
	// arguments instead of failing.
	PassOnUnknown UnknownHandling = iota
	// ErrorOnUnknown makes Parse fail when an unknown flag is found.
	ErrorOnUnknown
)
// Set is a named collection of flag groups together with the state
// produced by parsing command line arguments against them.
type Set struct {
// How Parse reports an error: returned to the caller or raised as a panic
errorHandling ErrorHandling
// Lookup from flag name to flag. Filled by initFlags with every flag's
// aliases, long name and (when set) short name
flagMap map[string]*Flag
// Groups holding the flags. DefaultGroup returns the first entry
groups []*Group
// Name of the set
name string
// Set to true once Parse has completed without error
parsed bool
// Arguments collected by Parse that were not consumed as flags or values
remaining []string
// Names of the flags Parse encountered that are not defined in the set
unknownFlags []string
// Whether Parse fails on, or passes through, flags that are not defined
unknownHandling UnknownHandling
}
// SetModifier is a function that adjusts a Set. NewSet applies every
// modifier it is given to the Set it is constructing.
type SetModifier func(s *Set)
// Visitor is a callback that receives a single flag. It is the function
// type accepted by Visit, VisitCalled and VisitAll.
type Visitor func(f *Flag)
// SetErrorMode builds a SetModifier that stores m as the error
// handling mode of the Set it is applied to.
func SetErrorMode(m ErrorHandling) SetModifier {
	apply := func(target *Set) {
		target.errorHandling = m
	}
	return apply
}
// SetUnknownMode builds a SetModifier that stores m as the unknown
// flag handling mode of the Set it is applied to.
func SetUnknownMode(m UnknownHandling) SetModifier {
	apply := func(target *Set) {
		target.unknownHandling = m
	}
	return apply
}
// NewSet creates a Set with the given name. The set starts out
// returning errors from Parse (ReturnOnError) and rejecting unknown
// flags (ErrorOnUnknown); the supplied modifiers are applied last, in
// order, and may override those defaults.
func NewSet(name string, modifiers ...SetModifier) *Set {
	s := new(Set)
	s.name = name
	s.errorHandling = ReturnOnError
	s.unknownHandling = ErrorOnUnknown

	// Collections start out empty but non-nil.
	s.groups = make([]*Group, 0)
	s.flagMap = make(map[string]*Flag)
	s.remaining = make([]string, 0)
	s.unknownFlags = make([]string, 0)

	// Register the internal group before any modifier runs. NewGroup
	// only reports an error for a duplicate group name, which cannot
	// happen here because the set has no groups yet.
	s.NewGroup(INTERNAL_GROUP_NAME, HideGroupName())

	for i := range modifiers {
		modifiers[i](s)
	}

	return s
}
// Name returns the name of this flag set.
func (s *Set) Name() string {
return s.name
}
// Groups returns all groups defined within the set. The returned
// slice is a copy, so changing it does not affect the set's own list.
func (s *Set) Groups() []*Group {
	out := make([]*Group, 0, len(s.groups))
	return append(out, s.groups...)
}
// Visit calls fn for every flag in the set that reports itself as
// updated, i.e. set either by the CLI or by an environment variable.
func (s *Set) Visit(fn Visitor) {
	for _, flg := range s.Flags() {
		if !flg.Updated() {
			continue
		}
		fn(flg)
	}
}
// VisitCalled calls fn for every flag in the set that reports itself
// as called, i.e. set by the CLI only.
func (s *Set) VisitCalled(fn Visitor) {
	for _, flg := range s.Flags() {
		if !flg.Called() {
			continue
		}
		fn(flg)
	}
}
// VisitAll calls fn for every flag in the set, whether or not it has
// been given a value.
func (s *Set) VisitAll(fn Visitor) {
	all := s.Flags()
	for i := range all {
		fn(all[i])
	}
}
// AddGroup moves an existing group into this set. The group is removed
// from the set that currently owns it (if any), re-pointed at s, and
// appended to the end of s's group list.
//
// An error is returned when g is nil or when g is already part of s.
//
// NOTE(review): unlike NewGroup, no check is made for another group in
// s carrying the same name — confirm whether that is intended.
func (s *Set) AddGroup(g *Group) error {
	if g == nil {
		return fmt.Errorf("cannot add nil group to set")
	}

	// Check that group hasn't already been added
	for _, cg := range s.groups {
		if g == cg {
			return fmt.Errorf("group already exists in set")
		}
	}

	// Remove the group from its current set. A group that has no
	// owning set has nothing to be detached from.
	if prev := g.set; prev != nil {
		for i, cg := range prev.groups {
			if cg == g {
				prev.groups = append(prev.groups[:i], prev.groups[i+1:]...)
				break
			}
		}
	}

	// Update the group's Set and add the group to this Set's groups
	g.set = s
	s.groups = append(s.groups, g)

	return nil
}
// NewGroup creates a new group with the given name for this set. An
// error is returned, and nothing is created, when the set already has
// a group with that name.
func (s *Set) NewGroup(name string, modifiers ...GroupModifier) error {
	for i := range s.groups {
		if s.groups[i].name != name {
			continue
		}
		return fmt.Errorf("flag group already exists with name %s", name)
	}

	// newGroup receives the set; presumably it attaches the new group
	// to it — verify against its definition.
	newGroup(s, name, modifiers...)

	return nil
}
// DefaultGroup returns the default group for flags, which is the first
// group in the set. The default group does not include a title section
// when displayed. It panics when the set has no groups at all.
func (s *Set) DefaultGroup() *Group {
	if len(s.groups) > 0 {
		return s.groups[0]
	}
	panic("default group does not exist")
}
// Flags returns every flag defined within the set, gathered group by
// group in the order the groups are stored. The result is a fresh,
// non-nil slice.
func (s *Set) Flags() []*Flag {
	all := make([]*Flag, 0)
	for i := range s.groups {
		all = append(all, s.groups[i].flags...)
	}
	return all
}
// Flag looks up a flag by name. The name may be the flag's long name,
// one of its aliases, or its short name.
//
// Once the set has been parsed the lookup uses the name table built
// during parsing. Before that, the flags are scanned directly and
// matched against the same three kinds of name, so a name resolves the
// same way before and after Parse.
//
// An error is returned when no flag answers to n.
func (s *Set) Flag(n string) (f *Flag, err error) {
	if s.parsed {
		f = s.flagMap[n]
	} else {
		for _, flg := range s.Flags() {
			if flg.longName == n ||
				(flg.shortName != 0 && string(flg.shortName) == n) {
				f = flg
				continue
			}
			// NOTE(review): aliases are read as they currently stand; if
			// a flag's init step adds aliases they are only visible here
			// after Parse — confirm.
			for _, alias := range flg.aliases {
				if alias == n {
					f = flg
					break
				}
			}
		}
	}

	if f == nil {
		err = fmt.Errorf("failed to locate flag named: %s", n)
	}

	return
}
// Display generates the flag usage output by concatenating the output
// of every group's Display, in group order.
func (s *Set) Display() (o string) {
	var b strings.Builder
	for _, g := range s.groups {
		// 11 is forwarded unchanged to the group's Display; its meaning
		// is defined there (presumably an indent/column width — confirm).
		b.WriteString(g.Display(11))
	}
	return b.String()
}
// Parse parses the command line options in args against the flags
// defined in the set.
//
// The returned slice holds the arguments that were not consumed as
// flags or flag values: non-flag arguments (including a lone `-`),
// everything after a `--` separator and, when unknown handling allows
// passing, the unknown flags themselves.
//
// A set can only be parsed once; a second call after a successful
// parse returns an error. When the set's error handling is
// PanicOnError, any error is raised as a panic instead of returned.
// On error the returned slice is nil.
func (s *Set) Parse(args []string) (remaining []string, err error) {
	defer func() {
		if err != nil && s.errorHandling == PanicOnError {
			panic(err)
		}
	}()

	// We only allow a set to parse once
	if s.parsed {
		return nil, fmt.Errorf("Set has already parsed arguments")
	}

	// Initialize all the flags. Errors returned from here
	// are either flag initialization issues or flag name
	// collisions
	if err = s.initFlags(); err != nil {
		return nil, err
	}

	// Now we start parsing
	for i := 0; i < len(args); i++ {
		w := args[i]

		// If the argument is the `--` separator, we stop parsing
		// and add the unprocessed arguments to the remaining list
		// to be returned
		if w == "--" {
			if i+1 < len(args) {
				// The remaining arguments may already be populated with
				// previously encountered unknown flags if the set is
				// configured for pass through. In that situation, we
				// want to retain the `--` separator
				if len(s.remaining) > 0 {
					s.remaining = append(s.remaining, args[i:]...)
				} else {
					s.remaining = append(s.remaining, args[i+1:]...)
				}
			}
			break
		}

		if strings.HasPrefix(w, "--") {
			// Handle long name flag. Accepted forms are `--name VAL`
			// and `--name=VAL`.
			name := strings.TrimPrefix(w, "--")
			var value string
			valueNext := true
			if eq := strings.IndexByte(name, '='); eq >= 0 {
				name, value = name[:eq], name[eq+1:]
				valueNext = false
			}

			f, ok := s.flagMap[name]
			// If the flag is not found check if we should error
			if !ok {
				if err = s.flagNotFound(name); err != nil {
					return nil, err
				}
				// Since we haven't errored, add flag to remaining
				s.remaining = append(s.remaining, w)
				s.unknownFlags = append(s.unknownFlags, name)
				continue
			}

			switch f.kind {
			case BooleanType:
				// Since boolean types can be negated, check the flag
				// name to set the correct value. Any value supplied in
				// the `--name=VAL` form is replaced here.
				if strings.HasPrefix(name, "no-") {
					value = "false"
				} else {
					value = "true"
				}
			case IncrementType:
				// Increment values don't matter, so just set as 1
				value = "1"
			default:
				// If the value was not included in the argument (argument
				// form was not --flag=VAL) then we need to get the value
				// from the next argument
				if valueNext {
					if i+1 >= len(args) {
						return nil, fmt.Errorf("missing argument for flag `--%s`", f.longName)
					}
					i++
					value = args[i]
				}
			}

			// Mark the flag as being called on the CLI and the name used
			f.markCalled(name)

			// And finally, set the value
			if err = f.setValue(value); err != nil {
				return nil, err
			}
		} else if len(w) > 1 && strings.HasPrefix(w, "-") {
			// For short flags, multiple patterns can be used. Valid examples:
			//
			// Boolean/Increment types can be chained: -vvvbx (-v -v -v -b -x)
			// Other types can include value in argument: -aVAL
			// Chaining can be used for both: -vvvbxaVAL
			//
			// A lone `-` is not a flag and is handled as a plain argument
			// in the final branch.
		wordLoop:
			for j := 1; j < len(w); j++ {
				c := string(w[j])
				f, ok := s.flagMap[c]

				// If the flag was not found check if we should error
				if !ok {
					if err = s.flagNotFound(c); err != nil {
						return nil, err
					}
					// Add the unprocessed rest of the word to remaining
					s.remaining = append(s.remaining, "-"+w[j:])
					// Only add the flag we encountered that was unknown
					s.unknownFlags = append(s.unknownFlags, c)
					// The rest of the word has just been passed through as
					// a unit (it may be the unknown flag's value), so stop
					// here. Carrying on would both re-add its tail to
					// remaining and apply any known flags inside it.
					break wordLoop
				}

				// Mark the flag as being called on the CLI and the name used
				f.markCalled(c)

				switch f.kind {
				case BooleanType:
					err = f.setValue("true")
				case IncrementType:
					err = f.setValue("1")
				default:
					// Check if we have anything left in this argument. If we
					// do, it is the value. Otherwise, get the value from the
					// next argument.
					if len(w)-1 == j {
						if i+1 >= len(args) {
							return nil, fmt.Errorf("missing argument for flag `-%s`", string(f.shortName))
						}
						i++
						err = f.setValue(args[i])
					} else {
						err = f.setValue(w[j+1:])
					}
					if err != nil {
						return nil, err
					}
					// The value consumed the rest of the word
					break wordLoop
				}

				// If an error was encountered, bail out
				if err != nil {
					return nil, err
				}
			}
		} else {
			// Not a flag: keep it as a remaining argument
			s.remaining = append(s.remaining, w)
		}
	}

	// TODO: need to validate for required flags

	s.parsed = true

	return s.remaining, nil
}
// flagNotFound decides how an undefined flag name is treated. It
// returns an error when the set is configured with ErrorOnUnknown and
// nil otherwise, in which case the caller passes the flag through.
func (s *Set) flagNotFound(name string) error {
	if s.unknownHandling != ErrorOnUnknown {
		return nil
	}
	return fmt.Errorf("unknown flag encountered `%s`", name)
}
// initFlags initializes every flag in the set and registers it in the
// name lookup table under each of its aliases, its long name and, when
// one is set, its short name.
//
// It returns the first flag initialization error, or a collision error
// as soon as a name is found to be registered already. Entries added
// before the failure stay in the table.
func (s *Set) initFlags() error {
	flags := s.Flags()
	for i := range flags {
		f := flags[i]
		if err := f.init(); err != nil {
			return err
		}

		// Names are registered in a fixed order: aliases first, then the
		// long name, then the short name.
		names := make([]string, len(f.aliases), len(f.aliases)+2)
		copy(names, f.aliases)
		names = append(names, f.longName)
		if f.shortName != 0 {
			names = append(names, string(f.shortName))
		}

		for _, n := range names {
			cf, taken := s.flagMap[n]
			if !taken {
				s.flagMap[n] = f
				continue
			}

			// Single character names are reported in short flag form
			colFlag := "--" + n
			if len(n) == 1 {
				colFlag = "-" + n
			}
			return fmt.Errorf("flags --%s and --%s share a common flag (collision on %s)",
				f.longName, cf.longName, colFlag)
		}
	}

	return nil
}