Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server.go: "/" for windows #571

Merged
merged 13 commits into from
Jan 3, 2025
8 changes: 8 additions & 0 deletions examples/go-sftp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ func main() {
var (
readOnly bool
debugStderr bool
winRoot bool
)

flag.BoolVar(&readOnly, "R", false, "read-only server")
flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
flag.BoolVar(&winRoot, "wr", false, "windows root")

flag.Parse()

debugStream := io.Discard
Expand Down Expand Up @@ -128,6 +131,11 @@ func main() {
fmt.Fprintf(debugStream, "Read write server\n")
}

if winRoot {
serverOptions = append(serverOptions, sftp.WindowsRootEnumeratesDrives())
fmt.Fprintf(debugStream, "Windows root enabled\n")
}

server, err := sftp.NewServer(
channel,
serverOptions...,
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ require (
github.com/kr/fs v0.1.0
github.com/stretchr/testify v1.8.0
golang.org/x/crypto v0.31.0
golang.org/x/sys v0.28.0 // indirect
)
36 changes: 29 additions & 7 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"io/ioutil"
"os"
"path/filepath"
Expand All @@ -21,6 +22,18 @@ const (
SftpServerWorkerCount = 8
)

type file interface {
Stat() (os.FileInfo, error)
ReadAt(b []byte, off int64) (int, error)
WriteAt(b []byte, off int64) (int, error)
Readdir(int) ([]os.FileInfo, error)
Name() string
Truncate(int64) error
Chmod(mode fs.FileMode) error
Chown(uid, gid int) error
Close() error
}

// Server is an SSH File Transfer Protocol (sftp) server.
// This is intended to provide the sftp subsystem to an ssh server daemon.
// This implementation currently supports most of sftp server protocol version 3,
Expand All @@ -30,14 +43,15 @@ type Server struct {
debugStream io.Writer
readOnly bool
pktMgr *packetManager
openFiles map[string]*os.File
openFiles map[string]file
openFilesLock sync.RWMutex
handleCount int
workDir string
winRoot bool
maxTxPacket uint32
}

func (svr *Server) nextHandle(f *os.File) string {
func (svr *Server) nextHandle(f file) string {
svr.openFilesLock.Lock()
defer svr.openFilesLock.Unlock()
svr.handleCount++
Expand All @@ -57,7 +71,7 @@ func (svr *Server) closeHandle(handle string) error {
return EBADF
}

func (svr *Server) getHandle(handle string) (*os.File, bool) {
func (svr *Server) getHandle(handle string) (file, bool) {
svr.openFilesLock.RLock()
defer svr.openFilesLock.RUnlock()
f, ok := svr.openFiles[handle]
Expand Down Expand Up @@ -86,7 +100,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error)
serverConn: svrConn,
debugStream: ioutil.Discard,
pktMgr: newPktMgr(svrConn),
openFiles: make(map[string]*os.File),
openFiles: make(map[string]file),
maxTxPacket: defaultMaxTxPacket,
}

Expand Down Expand Up @@ -118,6 +132,14 @@ func ReadOnly() ServerOption {
}
}

// WindowsRootEnumeratesDrives configures a Server to serve a virtual '/' for windows that lists all drives
func WindowsRootEnumeratesDrives() ServerOption {
return func(s *Server) error {
s.winRoot = true
return nil
}
}

// WithAllocator enable the allocator.
// After processing a packet we keep in memory the allocated slices
// and we reuse them for new packets.
Expand Down Expand Up @@ -215,7 +237,7 @@ func handlePacket(s *Server, p orderedRequest) error {
}
case *sshFxpLstatPacket:
// stat the requested file
info, err := os.Lstat(s.toLocalPath(p.Path))
info, err := s.lstat(s.toLocalPath(p.Path))
rpkt = &sshFxpStatResponse{
ID: p.ID,
info: info,
Expand Down Expand Up @@ -289,7 +311,7 @@ func handlePacket(s *Server, p orderedRequest) error {
case *sshFxpOpendirPacket:
lp := s.toLocalPath(p.Path)

if stat, err := os.Stat(lp); err != nil {
if stat, err := s.stat(lp); err != nil {
rpkt = statusFromError(p.ID, err)
} else if !stat.IsDir() {
rpkt = statusFromError(p.ID, &os.PathError{
Expand Down Expand Up @@ -493,7 +515,7 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
mode = fs.FileMode() & os.ModePerm
}

f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode)
f, err := svr.openfile(svr.toLocalPath(p.Path), osFlags, mode)
if err != nil {
return statusFromError(p.ID, err)
}
Expand Down
21 changes: 21 additions & 0 deletions server_posix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//go:build !windows
// +build !windows

package sftp

import (
"io/fs"
"os"
)

func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) {
return os.OpenFile(path, flag, mode)
}

func (s *Server) lstat(name string) (os.FileInfo, error) {
return os.Lstat(name)
}

func (s *Server) stat(name string) (os.FileInfo, error) {
return os.Stat(name)
}
156 changes: 155 additions & 1 deletion server_windows.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
package sftp

import (
"fmt"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"time"

"golang.org/x/sys/windows"
)

func (s *Server) toLocalPath(p string) string {
Expand All @@ -12,7 +19,11 @@ func (s *Server) toLocalPath(p string) string {

lp := filepath.FromSlash(p)

if path.IsAbs(p) {
if path.IsAbs(p) { // starts with '/'
if len(p) == 1 && s.winRoot {
return `\\.\` // for openfile
}

tmp := lp
for len(tmp) > 0 && tmp[0] == '\\' {
tmp = tmp[1:]
Expand All @@ -33,7 +44,150 @@ func (s *Server) toLocalPath(p string) string {
// e.g. "/C:" to "C:\\"
return tmp
}

if s.winRoot {
// Make it so that "/Windows" is not found, and "/c:/Windows" has to be used
return `\\.\` + tmp
}
}

return lp
}

func bitsToDrives(bitmap uint32) []string {
var drive rune = 'a'
var drives []string

for bitmap != 0 && drive <= 'z' {
if bitmap&1 == 1 {
drives = append(drives, string(drive)+":")
}
drive++
bitmap >>= 1
}

return drives
}

func getDrives() ([]string, error) {
mask, err := windows.GetLogicalDrives()
if err != nil {
return nil, fmt.Errorf("GetLogicalDrives: %w", err)
}
return bitsToDrives(mask), nil
}

type driveInfo struct {
fs.FileInfo
name string
}

func (i *driveInfo) Name() string {
return i.name // since the Name() returned from a os.Stat("C:\\") is "\\"
}

type winRoot struct {
drives []string
}

func newWinRoot() (*winRoot, error) {
drives, err := getDrives()
if err != nil {
return nil, err
}
return &winRoot{
drives: drives,
}, nil
}

func (f *winRoot) Readdir(n int) ([]os.FileInfo, error) {
drives := f.drives
if n > 0 && len(drives) > n {
drives = drives[:n]
}
f.drives = f.drives[len(drives):]
if len(drives) == 0 {
return nil, io.EOF
}

var infos []os.FileInfo
for _, drive := range drives {
fi, err := os.Stat(drive + `\`)
if err != nil {
return nil, err
}

di := &driveInfo{
FileInfo: fi,
name: drive,
}
infos = append(infos, di)
}

return infos, nil
}

func (f *winRoot) Stat() (os.FileInfo, error) {
return rootFileInfo, nil
}
func (f *winRoot) ReadAt(b []byte, off int64) (int, error) {
return 0, os.ErrPermission
}
func (f *winRoot) WriteAt(b []byte, off int64) (int, error) {
return 0, os.ErrPermission
}
func (f *winRoot) Name() string {
return "/"
}
func (f *winRoot) Truncate(int64) error {
return os.ErrPermission
}
func (f *winRoot) Chmod(mode fs.FileMode) error {
return os.ErrPermission
}
func (f *winRoot) Chown(uid, gid int) error {
return os.ErrPermission
}
func (f *winRoot) Close() error {
f.drives = nil
return nil
powellnorma marked this conversation as resolved.
Show resolved Hide resolved
}

func (s *Server) openfile(path string, flag int, mode fs.FileMode) (file, error) {
if path == `\\.\` && s.winRoot {
return newWinRoot()
}
return os.OpenFile(path, flag, mode)
}

type winRootFileInfo struct {
name string
modTime time.Time
}

func (w *winRootFileInfo) Name() string { return w.name }
func (w *winRootFileInfo) Size() int64 { return 0 }
func (w *winRootFileInfo) Mode() fs.FileMode { return fs.ModeDir | 0555 } // read+execute for all
func (w *winRootFileInfo) ModTime() time.Time { return w.modTime }
func (w *winRootFileInfo) IsDir() bool { return true }
func (w *winRootFileInfo) Sys() interface{} { return nil }

// Create a new root FileInfo
var rootFileInfo = &winRootFileInfo{
name: "/",
modTime: time.Now(),
}

func (s *Server) lstat(name string) (os.FileInfo, error) {
if name == `\\.\` && s.winRoot {
return rootFileInfo, nil
}
return os.Lstat(name)
}

func (s *Server) stat(name string) (os.FileInfo, error) {
if name == `\\.\` && s.winRoot {
return rootFileInfo, nil
}
return os.Stat(name)
}
Loading