golang使用flag添加命令行参数

golang使用flag添加命令行参数

代码:

// targs.go
package main

import (
    "log"
    "flag"
    "strings"
)

// 定义可重复参数
type strSlice []string

func (s *strSlice) Set(value string) error {
    *s = append(*s, value)
    return nil
}

func (s strSlice) String() string {
    return strings.Join(s, ",")
}

func main() {
    // 声明变量
    var ( 
        name = flag.String("name", "Lilei", "Give your name")
        age = flag.Int("age", 28, "Give your age")
        iscold = flag.Bool("iscold", false, "Are you cold")
        arglst strSlice
    )
    // 解析 arg 格式的参数
    flag.Var(&arglst, "arg", "some other arg")
    // 解析参数
    flag.Parse()

    // 输出测试
    log.Printf("name: [%s]", *name)
    log.Printf("age: [%d]", *age)
    log.Printf("iscold: [%t]", *iscold)
    log.Printf("arglst: [%v]", arglst)
}

测试:


扩展:

🌟 接口 flag.Value

flag.Var 用来注册一个“可变”的命令行标志(flag),它允许你使用自定义数据类型(比如 slice、map 等),只要这个类型实现了 flag.Value 接口。

Go 的标准库 flag.Value 接口定义:

type Value interface {
    String() string
    Set(string) error
}
方法作用
Set(string)当命令行传入 --xxx=value 时,Viper 调用这个方法来设置值
String()用于生成帮助信息(--help 输出)
var cmdargs strSlice
flag.Var(&cmdargs, "cmdarg", "cmd指令需要的参数")  // 来获取 --cmdarg 参数提供的数据

flag.Int flag.Float flag.String flag.Bool 返回的全部是指针
flag.IntVar flag.FloatVar flag.StringVar flag.BoolVar 传入变量地址来获取值。

demo:

package main

import (
    "flag"
    "fmt"
    "net"
    "strings"
    "time"
)

// 自定义类型定义
type stringSlice []string

func (s *stringSlice) Set(value string) error {
    *s = append(*s, value)
    return nil
}

func (s *stringSlice) String() string {
    return fmt.Sprintf("%v", []string(*s))
}

type LogLevel string

const (
    Debug LogLevel = "debug"
    Info  LogLevel = "info"
    Warn  LogLevel = "warn" 
    Error LogLevel = "error"
)

func (l *LogLevel) Set(value string) error {
    level := LogLevel(value)
    switch level {
    case Debug, Info, Warn, Error:
        *l = level
        return nil
    default:
        return fmt.Errorf("无效的日志级别: %s", value)
    }
}

func (l *LogLevel) String() string {
    return string(*l)
}

type IPList []net.IP

func (ips *IPList) Set(value string) error {
    ip := net.ParseIP(value)
    if ip == nil {
        return fmt.Errorf("无效的IP地址: %s", value)
    }
    *ips = append(*ips, ip)
    return nil
}

func (ips *IPList) String() string {
    ipsStr := make([]string, len(*ips))
    for i, ip := range *ips {
        ipsStr[i] = ip.String()
    }
    return strings.Join(ipsStr, ", ")
}

func main() {
    // 内置类型参数
    var (
        // 布尔类型
        debug = flag.Bool("debug", false, "启用调试模式")
        verbose = flag.Bool("verbose", false, "详细输出")
        
        // 整数类型 - 使用 Int 代替 Int32
        port = flag.Int("port", 8080, "服务端口")
        priority = flag.Int("priority", 1, "优先级")  // 改为 Int
        maxSize = flag.Int64("max-size", 1024, "最大尺寸(MB)")
        
        // 浮点数类型 - 使用 Float64 代替 Float32
        rate = flag.Float64("rate", 0.5, "处理速率")
        temperature = flag.Float64("temp", 25.5, "温度(℃)")  // 改为 Float64
        
        // 字符串类型
        name = flag.String("name", "World", "姓名")
        config = flag.String("config", "config.json", "配置文件路径")
        
        // 时间类型
        interval = flag.Duration("interval", 1*time.Second, "轮询间隔")
        timeout time.Duration
    )
    
    flag.DurationVar(&timeout, "timeout", 30*time.Second, "超时时间")
    
    // 自定义类型参数
    var (
        tags stringSlice
        logLevel LogLevel = Info
        allowedIPs IPList
    )
    
    flag.Var(&tags, "tag", "标签(可多次使用)")
    flag.Var(&logLevel, "log-level", "日志级别: debug, info, warn, error")
    flag.Var(&allowedIPs, "allow-ip", "允许的IP地址")
    
    // 解析参数
    flag.Parse()
    
    // 输出所有参数值
    fmt.Println("=== 参数解析结果 ===")
    fmt.Printf("调试模式: %v\n", *debug)
    fmt.Printf("详细输出: %v\n", *verbose)
    fmt.Printf("端口: %d\n", *port)
    fmt.Printf("优先级: %d\n", *priority)
    fmt.Printf("最大尺寸: %d MB\n", *maxSize)
    fmt.Printf("处理速率: %.2f\n", *rate)
    fmt.Printf("温度: %.1f℃\n", *temperature)
    fmt.Printf("姓名: %s\n", *name)
    fmt.Printf("配置文件: %s\n", *config)
    fmt.Printf("轮询间隔: %v\n", *interval)
    fmt.Printf("超时时间: %v\n", timeout)
    fmt.Printf("标签: %v\n", []string(tags))
    fmt.Printf("日志级别: %s\n", logLevel)
    fmt.Printf("允许的IP: %v\n", []net.IP(allowedIPs))
    
    // 显示帮助信息
    if len(flag.Args()) > 0 && flag.Args()[0] == "help" {
        flag.PrintDefaults()
    }
}

测试执行:

go run demo.go \
    --debug \
    --verbose \
    --port=9090 \
    --priority=5 \
    --max-size=2048 \
    --rate=0.8 \
    --temp=30.5 \
    --name="张三" \
    --config="/etc/app/config.yaml" \
    --interval=5s \
    --timeout=1m \
    --tag=api \
    --tag=database \
    --log-level=debug \
    --allow-ip=192.168.1.1 \
    --allow-ip=10.0.0.1

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注