Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement percent operators for config files #49

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 99 additions & 3 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ func homedir() string {
user, err := osuser.Current()
if err == nil {
return user.HomeDir
} else {
return os.Getenv("HOME")
}
return os.Getenv("HOME")
}

func userConfigFinder() string {
Expand Down Expand Up @@ -333,6 +332,102 @@ type Config struct {
position Position
}

// %% A literal `%'.
// %C Shorthand for %l%h%p%r.
// %d Local user's home directory.
// %h The remote hostname.
// %i The local user ID.
// %L The local hostname.
// %l The local hostname, including the domain name.
// %n The original remote hostname, as given on the command line.
// %p The remote port.
// %r The remote username.
// %u The local username.
func (host *Host) percent(alias, val string) string {
var (
b bytes.Buffer
sawPercent bool
)
for _, c := range val {
if sawPercent {
sawPercent = false
switch c {
case 'd':
b.WriteString(homedir())
case 'h':
b.WriteString(host.tryKV(alias, "HostName"))
case 'i':
b.WriteString(fmt.Sprintf("%d", os.Getuid()))
case 'L':
if h, err := os.Hostname(); err != nil {
b.WriteString(fmt.Sprintf("%%!L(%v)", err))
} else {
b.WriteString(h)
}
case 'n':
b.WriteString(alias)
case 'p':
b.WriteString(host.tryKV(alias, "Port"))
case 'r':
b.WriteString(host.tryKV(alias, "User"))
case 'u':
b.WriteString(os.Getenv("USER"))
case '%':
b.WriteString("%")
default:
// In the event of a bad format char, fmt returns
// the mangled string and no error.
// It may be best to follow that practice, as
// it gives you a much better idea where things
// went wrong.
b.WriteString(`%!` + string(c))
}
continue
}
if c != '%' {
b.WriteByte(byte(c))
continue
}
sawPercent = true
}
if sawPercent {
b.WriteString("%!(NOVERB)")
}
return b.String()
}

func (host *Host) findKV(alias, key string) (string, error) {
lowerKey := strings.ToLower(key)
for _, node := range host.Nodes {
switch t := node.(type) {
case *Empty:
continue
case *KV:
// "keys are case insensitive" per the spec
lkey := strings.ToLower(t.Key)
if lkey == "match" {
panic("can't handle Match directives")
}
if lkey == lowerKey {
return t.Value, nil
}
case *Include:
val := t.Get(alias, key)
if val != "" {
return val, nil
}
default:
return "", fmt.Errorf("unknown Node type %v", t)
}
}
return "", fmt.Errorf("%v has no key %v", alias, key)
}

func (host *Host) tryKV(alias, key string) string {
v, _ := host.findKV(alias, key)
return v
}

// Get finds the first value in the configuration that matches the alias and
// contains key. Get returns the empty string if no value was found, or if the
// Config contains an invalid conditional Include value.
Expand Down Expand Up @@ -411,6 +506,7 @@ func (c Config) String() string {
return marshal(c).String()
}

// MarshalText implements Marshal
func (c Config) MarshalText() ([]byte, error) {
return marshal(c).Bytes(), nil
}
Expand Down Expand Up @@ -792,7 +888,7 @@ func init() {
func newConfig() *Config {
return &Config{
Hosts: []*Host{
&Host{
{
implicit: true,
Patterns: []*Pattern{matchAll},
Nodes: make([]Node, 0),
Expand Down
35 changes: 35 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,38 @@ func TestNoTrailingNewline(t *testing.T) {
t.Errorf("wrong port: got %q want 4242", port)
}
}

func TestPercent(t *testing.T) {
b := bytes.NewBufferString(`Host wap
HostName wap.example.org
Port 22
User root
KexAlgorithms diffie-hellman-group1-sha1
`)
cfg, err := Decode(b)
if err != nil {
t.Fatal(err)
}
host := cfg.Hosts[1]
t.Logf("cfg is %v, %d hosts, Hosts %v, host %v", cfg, len(cfg.Hosts), cfg.Hosts, host)
home := os.Getenv("HOME")
user := os.Getenv("USER")

for _, tt := range []struct {
in string
out string
}{
{"hi", "hi"},
{"%dhi", home + "hi"},
{"%uhi", user + "hi"},
{"%h.%n.%p.%r.%u", "wap.example.org.wap.22.root." + user},
{"%Z", "%!Z"},
{"%", "%!(NOVERB)"},
{"%d%", home + "%!(NOVERB)"},
} {
o := host.percent("wap", tt.in)
if o != tt.out {
t.Errorf("%q: got %q, want %q", tt.in, o, tt.out)
}
}
}