Skip to content

Commit c01e7e3

Browse files
authored
Merge pull request #7 from Hill-98/add-options
Reimplement repeater.ps1
2 parents 2b4d8db + 91ed48f commit c01e7e3

File tree

6 files changed

+76
-120
lines changed

6 files changed

+76
-120
lines changed

config.go

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ import (
99
"log"
1010
"net"
1111
"os"
12+
"os/exec"
1213
"os/signal"
1314
"path/filepath"
14-
"strings"
1515
"syscall"
1616
)
1717

1818
type config struct {
1919
socketPath string
20+
powershellPath string
2021
foreground bool
2122
verbose bool
2223
stop bool
@@ -35,10 +36,26 @@ func defaultSocketPath() string {
3536
return filepath.Join(home, ".ssh", "wsl2-ssh-agent.sock")
3637
}
3738

39+
func powershellPath() string {
40+
path, err := exec.LookPath("powershell.exe")
41+
if err != nil {
42+
path := "/mnt/c/Windows/System32/WindowsPowerShell/v1.0/powershell.exe"
43+
_, err := os.Stat(path)
44+
if err == nil {
45+
return path
46+
} else {
47+
return ""
48+
}
49+
50+
}
51+
return path
52+
}
53+
3854
func newConfig() *config {
3955
c := &config{}
4056

4157
flag.StringVar(&c.socketPath, "socket", defaultSocketPath(), "a path of UNIX domain socket to listen")
58+
flag.StringVar(&c.powershellPath, "powershell-path", powershellPath(), "a path of Windows PowerShell (powershell.exe)")
4259
flag.BoolVar(&c.foreground, "foreground", false, "run in foreground mode")
4360
flag.BoolVar(&c.verbose, "verbose", false, "verbose mode")
4461
flag.StringVar(&c.logFile, "log", "", "a file path to write the log")
@@ -52,10 +69,15 @@ func newConfig() *config {
5269

5370
flag.Parse()
5471

72+
if c.powershellPath == "" {
73+
fmt.Printf("powershell.exe not found, use the -powershell-path to customize the path.\n")
74+
os.Exit(1)
75+
}
76+
5577
return c
5678
}
5779

58-
func (c *config) start() (context.Context, bool) {
80+
func (c *config) start() (context.Context) {
5981
if c.version {
6082
fmt.Printf("wsl2-ssh-agent %s\n", version)
6183
os.Exit(0)
@@ -110,13 +132,7 @@ func (c *config) start() (context.Context, bool) {
110132
signal.Ignore(syscall.SIGPIPE)
111133
ctx, _ := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
112134

113-
// check if ssh-agent.exe is older than 8.9
114-
ignoreOpenSSHExtensions := strings.Compare(getWinSshVersion(), "OpenSSH_for_Windows_8.9") == -1
115-
if ignoreOpenSSHExtensions {
116-
log.Printf("ssh-agent.exe seems to be old; ignore OpenSSH extension messages")
117-
}
118-
119-
return ctx, ignoreOpenSSHExtensions
135+
return ctx
120136
}
121137

122138
func (c *config) setupLogFile() {

main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ package main
33
func main() {
44
c := newConfig()
55

6-
ctx, ignoreOpenSSHExtensions := c.start()
6+
ctx := c.start()
77

8-
s := newServer(c.socketPath, ignoreOpenSSHExtensions)
8+
s := newServer(c.socketPath, c.powershellPath)
99

1010
s.run(ctx)
1111
}

repeater.go

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
package main
22

33
import (
4-
"bytes"
54
"context"
65
_ "embed"
76
"fmt"
87
"io"
98
"log"
109
"os/exec"
11-
"strings"
1210
"time"
1311
)
1412

@@ -28,11 +26,11 @@ var waitTimes []time.Duration = []time.Duration{
2826
}
2927

3028
// invoke PowerShell.exe and run
31-
func newRepeater(ctx context.Context) (*repeater, error) {
29+
func newRepeater(ctx context.Context, powershell string) (*repeater, error) {
3230
for i, limit := range waitTimes {
3331
log.Printf("invoking [W] in PowerShell.exe%s", trial(i))
3432

35-
cmd := exec.Command("PowerShell.exe", "-Command", "-")
33+
cmd := exec.Command(powershell, "-Command", "-")
3634
in, err := cmd.StdinPipe()
3735
if err != nil {
3836
continue
@@ -96,37 +94,6 @@ func (rep *repeater) terminate() {
9694
terminate(rep.cmd)
9795
}
9896

99-
func getWinSshVersion() string {
100-
for i, limit := range waitTimes {
101-
ctx, cancel := context.WithTimeout(context.Background(), limit)
102-
defer cancel()
103-
104-
log.Printf("check the version of ssh.exe%s", trial(i))
105-
106-
cmd := exec.CommandContext(ctx, "ssh.exe", "-V")
107-
108-
var stdout, stderr bytes.Buffer
109-
cmd.Stdout = &stdout
110-
cmd.Stderr = &stderr
111-
112-
err := cmd.Run()
113-
114-
if err != nil {
115-
log.Printf("failed to invoke ssh.exe: %s", err)
116-
continue
117-
}
118-
119-
version := strings.TrimSuffix(stderr.String(), "\r\n")
120-
121-
log.Printf("the version of ssh.exe: %#v", version)
122-
return version
123-
}
124-
125-
log.Printf("failed to check the version of ssh.exe")
126-
127-
return ""
128-
}
129-
13097
func trial(i int) string {
13198
if i == 0 {
13299
return ""

repeater.ps1

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,68 @@ Function Log($msg) {
33
$host.ui.WriteErrorLine("[W] $date $msg")
44
}
55

6-
Function RelayMessage($from, $to, $buf, $arrow) {
6+
Function ReadMessage($stream) {
7+
$buf = New-Object byte[] 4
78
$offset = 0
89
while ($offset -lt 4) {
9-
$n = $from.Read($buf, $offset, 4 - $offset);
10-
if ($n -eq 0) { exit }
10+
$n = $stream.Read($buf, $offset, 4 - $offset);
11+
if ($n -eq 0) {
12+
break
13+
}
1114
$offset += $n;
1215
}
13-
$len = (($buf[0] * 256 + $buf[1]) * 256 + $buf[2]) * 256 + $buf[3] + 4
14-
Log "[L] $arrow [W] $arrow ssh-agent.exe ($len B)"
15-
$len
16-
while ($offset -lt $len) {
17-
$n = $from.Read($buf, $offset, [Math]::Min($len, $buf.Length) - $offset)
18-
if ($n -eq 0) { exit }
19-
$offset += $n
20-
$to.Write($buf, 0, $offset)
21-
$len -= $offset
22-
$offset = 0
16+
if ($offset -eq 4) {
17+
$len = (($buf[0] * 256 + $buf[1]) * 256 + $buf[2]) * 256 + $buf[3] + 4
18+
[Array]::Resize([ref]$buf, $len)
19+
while ($offset -lt $buf.Length) {
20+
$n = $stream.Read($buf, $offset, $buf.Length - $offset)
21+
if ($n -eq 0) {
22+
break
23+
}
24+
$offset += $n
25+
}
2326
}
27+
[Array]::Resize([ref]$buf, $offset)
28+
return $buf
2429
}
2530

2631
Function MainLoop {
2732
Try {
28-
$buf = New-Object byte[] 8192
33+
$ignoreOpenSSHExtensions = $false
34+
Try {
35+
$sshAgentVersion = (Get-Command -CommandType Application ssh-agent.exe -ErrorAction Stop)[0].Version
36+
$ignoreOpenSSHExtensions = ($sshAgentVersion.Major -le 8 -and $sshAgentVersion.Minor -lt 9)
37+
Log "ssh-agent.exe version: $($sshAgentVersion.ToString()) (ignoreOpenSSHExtensions: $ignoreOpenSSHExtensions)"
38+
}
39+
Catch {
40+
$ignoreOpenSSHExtensions = $true
41+
}
42+
2943
$ssh_client_in = [console]::OpenStandardInput()
3044
$ssh_client_out = [console]::OpenStandardOutput()
3145

3246
$ver = $PSVersionTable["PSVersion"]
47+
$ssh_client_out.WriteByte(0xff)
3348
Log "ready: PSVersion $ver"
3449

35-
$buf[0] = 0xff
36-
$ssh_client_out.Write($buf, 0, 1)
37-
3850
while ($true) {
3951
Try {
4052
$null = $ssh_client_in.Read((New-Object byte[] 1), 0, 0)
53+
$buf = ReadMessage $ssh_client_in
54+
if ($ignoreOpenSSHExtensions -and $buf.Length -gt 4 -and $buf[4] -eq 0x1b) {
55+
$buf = [byte[]](0, 0, 0, 1, 6)
56+
$ssh_client_out.Write($buf, 0, $buf.Length)
57+
Log "[W] return dummy for OpenSSH ext."
58+
Continue
59+
}
4160
$ssh_agent = New-Object System.IO.Pipes.NamedPipeClientStream ".", "openssh-ssh-agent", InOut
4261
$ssh_agent.Connect()
4362
Log "[W] named pipe: connected"
44-
$len = RelayMessage $ssh_client_in $ssh_agent $buf "->"
45-
$len = RelayMessage $ssh_agent $ssh_client_out $buf "<-"
63+
$ssh_agent.Write($buf, 0, $buf.Length)
64+
Log "[L] -> [W] -> ssh-agent.exe ($($buf.Length) B)"
65+
$buf = ReadMessage $ssh_agent
66+
$ssh_client_out.Write($buf, 0, $buf.Length)
67+
Log "[L] <- [W] <- ssh-agent.exe ($($buf.Length) B)"
4668
}
4769
Finally {
4870
if ($null -ne $ssh_agent) {

repeater_test.go

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ loop do
2222
end
2323
`
2424

25-
const dummySsh = `#!/usr/bin/ruby
26-
$stderr << "Hello\r\n"
27-
`
28-
2925
func setupDummyEnv(t *testing.T) string {
3026
t.Helper()
3127
log.SetOutput(io.Discard)
@@ -93,26 +89,3 @@ func TestRepeaterNormal(t *testing.T) {
9389

9490
rep.terminate()
9591
}
96-
97-
func TestSshVersionNoSsh(t *testing.T) {
98-
setupDummyEnv(t)
99-
100-
s := getWinSshVersion()
101-
if s != "" {
102-
t.Errorf("getWinSshVersion should fail")
103-
}
104-
}
105-
106-
func TestSshVersionNormal(t *testing.T) {
107-
tmpDir := setupDummyEnv(t)
108-
109-
err := os.WriteFile(filepath.Join(tmpDir, "ssh.exe"), []byte(dummySsh), 0777)
110-
if err != nil {
111-
t.Fatal(err)
112-
}
113-
114-
s := getWinSshVersion()
115-
if s != "Hello" {
116-
t.Errorf("getWinSshVersion does not work well: %#v", s)
117-
}
118-
}

server.go

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@ import (
1313

1414
type server struct {
1515
listener net.Listener
16-
ignoreOpenSSHExtensions bool
16+
powershellPath string
1717
}
1818

19-
func newServer(path string, ignoreOpenSSHExtensions bool) *server {
20-
listener, err := net.Listen("unix", path)
19+
func newServer(socketPath string, powershellPath string) *server {
20+
listener, err := net.Listen("unix", socketPath)
2121
if err != nil {
2222
log.Fatal(err)
2323
}
24-
log.Printf("start listening on %s", path)
24+
log.Printf("start listening on %s", socketPath)
2525

26-
return &server{listener, ignoreOpenSSHExtensions}
26+
return &server{listener, powershellPath}
2727
}
2828

2929
type request struct {
@@ -91,7 +91,7 @@ func (s *server) server(ctx context.Context, cancel func(), requestQueue chan re
9191

9292
for {
9393
// invoke PowerShell.exe
94-
rep, err := newRepeater(ctx)
94+
rep, err := newRepeater(ctx, s.powershellPath)
9595
if err != nil {
9696
return
9797
}
@@ -180,22 +180,6 @@ func (s *server) client(wg *sync.WaitGroup, ctx context.Context, sshClient net.C
180180
}
181181
log.Printf("ssh -> [L] (%d B)", len(req))
182182

183-
if s.ignoreOpenSSHExtensions && req[4] == 0x1b /* SSH_AGENTC_EXTENSION */ {
184-
// This is OpenSSH's extension message since OpenSSH 8.9.
185-
//
186-
// ref: https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.agent
187-
//
188-
// If we pass this message to ssh-agent.exe in OpenSSH 8.6, the connection is closed.
189-
// So we need to drop this message and send a dummy success response.
190-
log.Printf("ssh <- [L] (5 B) <dummy for OpenSSH ext.>")
191-
err := replyDummySuccess(sshClient, 0)
192-
if err != nil {
193-
log.Printf("failed to write to ssh: %s", err)
194-
break
195-
}
196-
continue
197-
}
198-
199183
requestQueue <- request{data: req, resultChannel: resChan}
200184
resp, ok := <-resChan
201185
if !ok {
@@ -241,9 +225,3 @@ func readMessage(from io.Reader) ([]byte, error) {
241225

242226
return append(header, body...), nil
243227
}
244-
245-
func replyDummySuccess(client io.ReadWriter, len int64) error {
246-
buf := []byte{0, 0, 0, 1, 6 /* SSH_AGENT_SUCCESS */}
247-
_, err := client.Write(buf)
248-
return err
249-
}

0 commit comments

Comments
 (0)