@@ -3,87 +3,93 @@ package firewall
3
3
import (
4
4
"context"
5
5
"fmt"
6
+ "strconv"
6
7
)
7
8
8
- // SetAllowedPort sets a port to be allowed for incoming traffic.
9
- func (c * Config ) SetAllowedPort (ctx context.Context , port uint16 , remove bool ) (err error ) {
9
+ func (c * Config ) SetAllowedPort (ctx context.Context , port uint16 , intf string ) (err error ) {
10
10
c .stateMutex .Lock ()
11
11
defer c .stateMutex .Unlock ()
12
12
13
+ if port == 0 {
14
+ return nil
15
+ }
16
+
13
17
if ! c .enabled {
18
+ c .logger .Info ("firewall disabled, only updating allowed ports internal state" )
19
+ existingInterfaces , ok := c .allowedInputPorts [port ]
20
+ if ! ok {
21
+ existingInterfaces = make (map [string ]struct {})
22
+ }
23
+ existingInterfaces [intf ] = struct {}{}
24
+ c .allowedInputPorts [port ] = existingInterfaces
14
25
return nil
15
26
}
16
27
17
- if err := c .setAllowedPort (ctx , port , remove ); err != nil {
18
- return err
28
+ netInterfaces , has := c .allowedInputPorts [port ]
29
+ if ! has {
30
+ netInterfaces = make (map [string ]struct {})
31
+ } else if _ , exists := netInterfaces [intf ]; exists {
32
+ return nil
19
33
}
20
34
21
- // Apply custom rules after setting allowed port
22
- if err := c .runUserPostRules (ctx , c .customRulesPath , remove ); err != nil {
35
+ c .logger .Info ("setting allowed input port " + fmt .Sprint (port ) + " through interface " + intf + "..." )
36
+
37
+ const remove = false
38
+ if err := c .acceptInputToPort (ctx , intf , port , remove ); err != nil {
39
+ return fmt .Errorf ("allowing input to port %d through interface %s: %w" ,
40
+ port , intf , err )
41
+ }
42
+ netInterfaces [intf ] = struct {}{}
43
+ c .allowedInputPorts [port ] = netInterfaces
44
+
45
+ // Apply user-defined post firewall rules
46
+ if err := c .runUserPostRules (ctx , c .customRulesPath , false ); err != nil {
23
47
return fmt .Errorf ("running user defined post firewall rules: %w" , err )
24
48
}
25
49
26
50
return nil
27
51
}
28
52
29
- func (c * Config ) setAllowedPort (ctx context.Context , port uint16 , remove bool ) (err error ) {
53
+ func (c * Config ) RemoveAllowedPort (ctx context.Context , port uint16 ) (err error ) {
30
54
c .stateMutex .Lock ()
31
55
defer c .stateMutex .Unlock ()
32
56
33
- netInterface := "any" // Assuming "any" is the default interface
57
+ if port == 0 {
58
+ return nil
59
+ }
34
60
35
61
if ! c .enabled {
36
- if remove {
37
- if _ , ok := c .allowedInputPorts [port ]; ! ok {
38
- return nil
39
- }
40
- if _ , ok := c.allowedInputPorts [port ][netInterface ]; ! ok {
41
- return nil
42
- }
43
- delete (c .allowedInputPorts [port ], netInterface )
44
- if len (c .allowedInputPorts [port ]) == 0 {
45
- delete (c .allowedInputPorts , port )
46
- }
47
- } else {
48
- if _ , ok := c .allowedInputPorts [port ]; ! ok {
49
- c .allowedInputPorts [port ] = make (map [string ]struct {})
50
- }
51
- c.allowedInputPorts [port ][netInterface ] = struct {}{}
52
- }
62
+ c .logger .Info ("firewall disabled, only updating allowed ports internal list" )
63
+ delete (c .allowedInputPorts , port )
53
64
return nil
54
65
}
55
66
56
- if remove {
57
- c .logger .Info (fmt .Sprintf ("removing allowed port %d..." , port ))
58
- if err := c .acceptInputToPort (ctx , netInterface , port , remove ); err != nil {
59
- return fmt .Errorf ("removing port %d: %w" , port , err )
60
- }
61
- if _ , ok := c .allowedInputPorts [port ]; ! ok {
62
- return nil
63
- }
64
- if _ , ok := c.allowedInputPorts [port ][netInterface ]; ! ok {
65
- return nil
66
- }
67
- delete (c .allowedInputPorts [port ], netInterface )
68
- if len (c .allowedInputPorts [port ]) == 0 {
69
- delete (c .allowedInputPorts , port )
70
- }
71
- } else {
72
- if err := c .acceptInputToPort (ctx , netInterface , port , remove ); err != nil {
73
- return fmt .Errorf ("adding port %d: %w" , port , err )
74
- }
75
- if _ , ok := c .allowedInputPorts [port ]; ! ok {
76
- c .allowedInputPorts [port ] = make (map [string ]struct {})
67
+ c .logger .Info ("removing allowed port " + strconv .Itoa (int (port )) + "..." )
68
+
69
+ interfacesSet , ok := c .allowedInputPorts [port ]
70
+ if ! ok {
71
+ return nil
72
+ }
73
+
74
+ const remove = true
75
+ for netInterface := range interfacesSet {
76
+ err := c .acceptInputToPort (ctx , netInterface , port , remove )
77
+ if err != nil {
78
+ return fmt .Errorf ("removing allowed port %d on interface %s: %w" ,
79
+ port , netInterface , err )
77
80
}
78
- c. allowedInputPorts [ port ][ netInterface ] = struct {}{}
81
+ delete ( interfacesSet , netInterface )
79
82
}
80
83
81
- return nil
82
- }
84
+ // All interfaces were removed successfully, so remove the port entry.
85
+ delete ( c . allowedInputPorts , port )
83
86
84
- // RemoveAllowedPort removes a port from the allowed list for incoming traffic.
85
- func (c * Config ) RemoveAllowedPort (ctx context.Context , port uint16 ) (err error ) {
86
- return c .SetAllowedPort (ctx , port , true )
87
+ // Apply user-defined post firewall rules
88
+ if err := c .runUserPostRules (ctx , c .customRulesPath , false ); err != nil {
89
+ return fmt .Errorf ("running user defined post firewall rules: %w" , err )
90
+ }
91
+
92
+ return nil
87
93
}
88
94
89
95
func (c * Config ) SetPortRedirection (ctx context.Context , interfaceName string ,
0 commit comments