@@ -26,24 +26,24 @@ fn generate_session_id() -> String {
26
26
use std:: collections:: hash_map:: DefaultHasher ;
27
27
use std:: hash:: { Hash , Hasher } ;
28
28
use std:: time:: { SystemTime , UNIX_EPOCH } ;
29
-
29
+
30
30
let timestamp = SystemTime :: now ( )
31
31
. duration_since ( UNIX_EPOCH )
32
32
. unwrap ( )
33
33
. as_nanos ( ) ;
34
34
let pid = std:: process:: id ( ) ;
35
35
let thread_id = std:: thread:: current ( ) . id ( ) ;
36
-
36
+
37
37
let mut hasher = DefaultHasher :: new ( ) ;
38
38
timestamp. hash ( & mut hasher) ;
39
39
pid. hash ( & mut hasher) ;
40
40
thread_id. hash ( & mut hasher) ;
41
-
41
+
42
42
let stack_var = 42u64 ;
43
43
( & stack_var as * const u64 as usize ) . hash ( & mut hasher) ;
44
-
44
+
45
45
let hash = hasher. finish ( ) ;
46
-
46
+
47
47
format ! ( "{:016x}_{}" , hash, pid)
48
48
}
49
49
@@ -57,9 +57,9 @@ pub struct ShaiPtyManager {
57
57
impl ShaiPtyManager {
58
58
pub fn new ( ) -> Result < Self , Box < dyn std:: error:: Error > > {
59
59
let ( master_fd, slave_fd) = Self :: create_pty_pair ( ) ?;
60
- Ok ( Self {
61
- master_fd,
62
- slave_fd,
60
+ Ok ( Self {
61
+ master_fd,
62
+ slave_fd,
63
63
session_id : generate_session_id ( ) ,
64
64
temp_rc_file : None ,
65
65
} )
@@ -94,11 +94,11 @@ impl ShaiPtyManager {
94
94
std:: ffi:: CStr :: from_ptr ( ptr) . to_string_lossy ( ) . into_owned ( )
95
95
} ;
96
96
97
- let slave_fd = unsafe {
97
+ let slave_fd = unsafe {
98
98
libc:: open (
99
- slave_name. as_ptr ( ) as * const i8 ,
99
+ slave_name. as_ptr ( ) as * const i8 ,
100
100
libc:: O_RDWR | libc:: O_NOCTTY
101
- )
101
+ )
102
102
} ;
103
103
if slave_fd == -1 {
104
104
unsafe { libc:: close ( master_fd) } ;
@@ -118,45 +118,45 @@ impl ShaiPtyManager {
118
118
self . setup_window_resize_handler ( ) ?;
119
119
120
120
let pid = unsafe { libc:: fork ( ) } ;
121
-
121
+
122
122
if pid == 0 {
123
- // CHILD: Become the shell
124
- self . setup_child_process ( shell, quiet) ;
123
+ // CHILD: Become the shell
124
+ self . setup_child_process ( shell, quiet) ;
125
125
} else if pid > 0 {
126
126
// PARENT: Handle I/O and run buffer server
127
127
unsafe { libc:: close ( self . slave_fd ) } ;
128
-
129
- let io_server = ShaiSessionServer :: new ( & self . session_id , 100 , 1000 ) ;
128
+
129
+ let io_server = ShaiSessionServer :: new ( & self . session_id , 100 , 1000 ) ;
130
130
io_server. start ( ) ?;
131
131
132
132
self . inject_shai_hooks ( & shell) ?;
133
133
134
134
self . handle_io_forwarding ( io_server, pid) ?;
135
-
135
+
136
136
MASTER_FD . store ( -1 , Ordering :: Relaxed ) ;
137
137
unsafe { libc:: close ( self . master_fd ) } ;
138
138
terminal. restore ( ) ;
139
139
} else {
140
140
// FORK FAILED
141
- unsafe {
141
+ unsafe {
142
142
libc:: close ( self . master_fd ) ;
143
143
libc:: close ( self . slave_fd ) ;
144
144
} ;
145
145
MASTER_FD . store ( -1 , Ordering :: Relaxed ) ;
146
146
terminal. restore ( ) ;
147
147
return Err ( "Fork failed" . into ( ) ) ;
148
148
}
149
-
149
+
150
150
Ok ( ( ) )
151
151
}
152
152
153
153
fn inject_shai_hooks ( & mut self , shell : & Shell ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
154
154
// Create temp file with RC content
155
155
let mut temp_file = NamedTempFile :: new ( ) ?;
156
156
temp_file. write_all ( shell. generate_rc_content ( ) . as_bytes ( ) ) ?;
157
- temp_file. flush ( ) ?;
157
+ temp_file. flush ( ) ?;
158
158
let temp_path = temp_file. path ( ) . to_string_lossy ( ) ;
159
-
159
+
160
160
// create source cmd
161
161
let source_cmd = match shell. shell_type {
162
162
ShellType :: Bash | ShellType :: Sh | ShellType :: Zsh | ShellType :: Fish => {
@@ -166,20 +166,20 @@ impl ShaiPtyManager {
166
166
format ! ( ". '{}'\n " , temp_path)
167
167
}
168
168
} ;
169
-
169
+
170
170
// Send source command to shell via stdin
171
171
let bytes_written = unsafe {
172
172
libc:: write (
173
- self . master_fd ,
174
- source_cmd. as_ptr ( ) as * const libc:: c_void ,
173
+ self . master_fd ,
174
+ source_cmd. as_ptr ( ) as * const libc:: c_void ,
175
175
source_cmd. len ( )
176
176
)
177
177
} ;
178
-
178
+
179
179
if bytes_written == -1 {
180
180
return Err ( "Failed to inject shai hooks" . into ( ) ) ;
181
181
}
182
-
182
+
183
183
let ( _file, kept_path) = temp_file. keep ( ) ?;
184
184
self . temp_rc_file = Some ( kept_path) ;
185
185
@@ -192,44 +192,44 @@ impl ShaiPtyManager {
192
192
sa. sa_sigaction = handle_sigwinch as usize ;
193
193
libc:: sigemptyset ( & mut sa. sa_mask ) ;
194
194
sa. sa_flags = libc:: SA_RESTART ;
195
-
195
+
196
196
libc:: sigaction ( libc:: SIGWINCH , & sa, std:: ptr:: null_mut ( ) ) ;
197
197
}
198
-
198
+
199
199
Ok ( ( ) )
200
200
}
201
201
202
202
fn setup_child_process ( & self , shell : Shell , quiet : bool ) -> ! {
203
203
unsafe {
204
204
libc:: close ( self . master_fd ) ;
205
-
205
+
206
206
// Set SHAI_SESSION_ID environment variable
207
207
let session_env = std:: ffi:: CString :: new ( "SHAI_SESSION_ID" ) . unwrap ( ) ;
208
208
let session_value = std:: ffi:: CString :: new ( self . session_id . as_str ( ) ) . unwrap ( ) ;
209
209
libc:: setenv ( session_env. as_ptr ( ) , session_value. as_ptr ( ) , 1 ) ;
210
-
210
+
211
211
if quiet {
212
212
let tmux_env = std:: ffi:: CString :: new ( "TMUX" ) . unwrap ( ) ;
213
213
libc:: unsetenv ( tmux_env. as_ptr ( ) ) ;
214
-
214
+
215
215
let term_session_env = std:: ffi:: CString :: new ( "TERM_SESSION_ID" ) . unwrap ( ) ;
216
216
libc:: unsetenv ( term_session_env. as_ptr ( ) ) ;
217
217
}
218
-
218
+
219
219
libc:: setsid ( ) ;
220
-
221
- if libc:: ioctl ( self . slave_fd , libc:: TIOCSCTTY as libc :: c_ulong , 0 ) == -1 {
220
+
221
+ if libc:: ioctl ( self . slave_fd , libc:: TIOCSCTTY . try_into ( ) . unwrap ( ) , 0 ) == -1 {
222
222
libc:: exit ( 1 ) ;
223
223
}
224
-
224
+
225
225
libc:: dup2 ( self . slave_fd , libc:: STDIN_FILENO ) ;
226
226
libc:: dup2 ( self . slave_fd , libc:: STDOUT_FILENO ) ;
227
227
libc:: dup2 ( self . slave_fd , libc:: STDERR_FILENO ) ;
228
-
228
+
229
229
if self . slave_fd > 2 {
230
230
libc:: close ( self . slave_fd ) ;
231
231
}
232
-
232
+
233
233
let shell_cstr = std:: ffi:: CString :: new ( shell. path ) . unwrap ( ) ;
234
234
let interactive_arg = std:: ffi:: CString :: new ( "-i" ) . unwrap ( ) ;
235
235
libc:: execl ( shell_cstr. as_ptr ( ) , shell_cstr. as_ptr ( ) , interactive_arg. as_ptr ( ) , std:: ptr:: null :: < i8 > ( ) ) ;
@@ -239,18 +239,18 @@ impl ShaiPtyManager {
239
239
240
240
fn handle_io_forwarding ( & self , io_server : ShaiSessionServer , child_pid : i32 ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
241
241
let master_fd_clone = self . master_fd ;
242
-
242
+
243
243
// loop to handle user input and send it to shell stdin
244
244
let _stdin_thread = thread:: spawn ( move || {
245
245
let mut stdin = io:: stdin ( ) ;
246
246
let mut buffer = [ 0u8 ; 1024 ] ;
247
-
247
+
248
248
loop {
249
249
match stdin. read ( & mut buffer) {
250
250
Ok ( 0 ) => break , // EOF
251
251
Ok ( n) => {
252
252
let input = & buffer[ ..n] ;
253
-
253
+
254
254
if unsafe { libc:: write ( master_fd_clone, input. as_ptr ( ) as * const libc:: c_void , n) } == -1 {
255
255
break ;
256
256
}
@@ -263,7 +263,7 @@ impl ShaiPtyManager {
263
263
let mut stdout = io:: stdout ( ) ;
264
264
let mut buffer = [ 0u8 ; 1024 ] ;
265
265
266
-
266
+
267
267
// consume until MAGIC_COOKIE is read (this is to avoid ugly sourcing echo)
268
268
if self . temp_rc_file . is_some ( ) {
269
269
let cookie = format ! ( "{}" , MAGIC_COOKIE ) . into_bytes ( ) ;
@@ -276,29 +276,29 @@ impl ShaiPtyManager {
276
276
while unsafe { libc:: read ( self . master_fd , b. as_mut_ptr ( ) as * mut _ , 1 ) } > 0 {
277
277
if b[ 0 ] != b'\r' && b[ 0 ] == b'\n' {
278
278
break ;
279
- }
279
+ }
280
280
}
281
- }
281
+ }
282
282
283
283
// loop to handle shell stdout and print it to user tty
284
284
loop {
285
- let bytes_read = unsafe {
286
- libc:: read ( self . master_fd , buffer. as_mut_ptr ( ) as * mut libc:: c_void , buffer. len ( ) )
285
+ let bytes_read = unsafe {
286
+ libc:: read ( self . master_fd , buffer. as_mut_ptr ( ) as * mut libc:: c_void , buffer. len ( ) )
287
287
} ;
288
-
288
+
289
289
if bytes_read <= 0 { break ; }
290
-
290
+
291
291
let output_data = & buffer[ ..bytes_read as usize ] ;
292
292
io_server. add_output ( output_data) ;
293
-
293
+
294
294
if stdout. write_all ( output_data) . is_err ( ) { break ; }
295
295
stdout. flush ( ) . ok ( ) ;
296
296
}
297
297
298
298
299
299
// wait for both loop to end
300
300
let mut status = 0 ;
301
- unsafe {
301
+ unsafe {
302
302
libc:: waitpid ( child_pid, & mut status, 0 ) ;
303
303
} ;
304
304
@@ -340,12 +340,12 @@ mod tests {
340
340
#[ test]
341
341
fn test_pty_manager_creation ( ) {
342
342
let pty = ShaiPtyManager :: new ( ) . unwrap ( ) ;
343
-
343
+
344
344
// Should have valid file descriptors
345
345
assert ! ( pty. master_fd >= 0 ) ;
346
346
assert ! ( pty. slave_fd >= 0 ) ;
347
347
assert ! ( pty. master_fd != pty. slave_fd) ;
348
-
348
+
349
349
// Should have a session ID
350
350
assert ! ( !pty. get_session_id( ) . is_empty( ) ) ;
351
351
assert ! ( pty. get_session_id( ) . contains( "_" ) ) ; // timestamp_pid format
@@ -355,46 +355,46 @@ mod tests {
355
355
fn test_session_id_format ( ) {
356
356
let pty = ShaiPtyManager :: new ( ) . unwrap ( ) ;
357
357
let session_id = pty. get_session_id ( ) ;
358
-
358
+
359
359
// Should be in format: hash_pid
360
360
let parts: Vec < & str > = session_id. split ( '_' ) . collect ( ) ;
361
361
assert_eq ! ( parts. len( ) , 2 ) ;
362
-
362
+
363
363
// First part should be a hex hash (16 characters)
364
364
let hash_part = parts[ 0 ] ;
365
365
assert_eq ! ( hash_part. len( ) , 16 , "Hash should be 16 hex characters" ) ;
366
366
assert ! ( hash_part. chars( ) . all( |c| c. is_ascii_hexdigit( ) ) , "Hash should only contain hex digits" ) ;
367
-
367
+
368
368
// Second part should be a PID (number)
369
369
assert ! ( parts[ 1 ] . parse:: <u32 >( ) . is_ok( ) , "Second part should be a valid PID" ) ;
370
-
370
+
371
371
// Verify it matches current process PID
372
372
let expected_pid = std:: process:: id ( ) ;
373
373
let actual_pid: u32 = parts[ 1 ] . parse ( ) . unwrap ( ) ;
374
374
assert_eq ! ( actual_pid, expected_pid, "PID should match current process" ) ;
375
375
}
376
-
376
+
377
377
#[ test]
378
378
fn test_unique_session_ids ( ) {
379
379
let pty1 = ShaiPtyManager :: new ( ) . unwrap ( ) ;
380
380
thread:: sleep ( Duration :: from_millis ( 1 ) ) ; // Ensure different timestamp
381
381
let pty2 = ShaiPtyManager :: new ( ) . unwrap ( ) ;
382
-
382
+
383
383
assert_ne ! ( pty1. get_session_id( ) , pty2. get_session_id( ) ) ;
384
384
}
385
385
386
386
#[ test]
387
387
fn test_create_pty_pair ( ) {
388
388
let result = ShaiPtyManager :: create_pty_pair ( ) ;
389
389
assert ! ( result. is_ok( ) ) ;
390
-
390
+
391
391
let ( master_fd, slave_fd) = result. unwrap ( ) ;
392
-
392
+
393
393
// Valid file descriptors
394
394
assert ! ( master_fd >= 0 ) ;
395
395
assert ! ( slave_fd >= 0 ) ;
396
396
assert_ne ! ( master_fd, slave_fd) ;
397
-
397
+
398
398
// Clean up
399
399
unsafe {
400
400
libc:: close ( master_fd) ;
@@ -405,23 +405,23 @@ mod tests {
405
405
#[ test]
406
406
fn test_pty_pair_communication ( ) {
407
407
let ( master_fd, slave_fd) = ShaiPtyManager :: create_pty_pair ( ) . unwrap ( ) ;
408
-
408
+
409
409
// Write to master, should be readable from slave
410
410
let test_data = b"Hello PTY\n " ;
411
411
let bytes_written = unsafe {
412
412
libc:: write ( master_fd, test_data. as_ptr ( ) as * const libc:: c_void , test_data. len ( ) )
413
413
} ;
414
414
assert_eq ! ( bytes_written, test_data. len( ) as isize ) ;
415
-
415
+
416
416
// Read from slave
417
417
let mut buffer = [ 0u8 ; 64 ] ;
418
418
let bytes_read = unsafe {
419
419
libc:: read ( slave_fd, buffer. as_mut_ptr ( ) as * mut libc:: c_void , buffer. len ( ) )
420
420
} ;
421
-
421
+
422
422
assert ! ( bytes_read > 0 ) ;
423
423
assert_eq ! ( & buffer[ ..bytes_read as usize ] , test_data) ;
424
-
424
+
425
425
// Clean up
426
426
unsafe {
427
427
libc:: close ( master_fd) ;
@@ -432,21 +432,21 @@ mod tests {
432
432
#[ test]
433
433
fn test_multiple_pty_creation ( ) {
434
434
let mut ptys = Vec :: new ( ) ;
435
-
435
+
436
436
// Create multiple PTYs
437
437
for _ in 0 ..5 {
438
438
let pty = ShaiPtyManager :: new ( ) . unwrap ( ) ;
439
439
ptys. push ( pty) ;
440
440
}
441
-
441
+
442
442
// All should have unique session IDs
443
443
let session_ids: Vec < String > = ptys. iter ( ) . map ( |p| p. get_session_id ( ) . to_string ( ) ) . collect ( ) ;
444
444
let mut unique_ids = session_ids. clone ( ) ;
445
445
unique_ids. sort ( ) ;
446
446
unique_ids. dedup ( ) ;
447
-
447
+
448
448
assert_eq ! ( session_ids. len( ) , unique_ids. len( ) ) ;
449
-
449
+
450
450
for pty in & ptys {
451
451
assert ! ( pty. master_fd >= 0 ) ;
452
452
assert ! ( pty. slave_fd >= 0 ) ;
0 commit comments