Skip to content

Commit 50fb2c3

Browse files
committed
feat(post-rules) execute post-rules after every update to the firewall
1 parent 9933dd3 commit 50fb2c3

File tree

5 files changed

+123
-106
lines changed

5 files changed

+123
-106
lines changed

internal/firewall/firewall.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ type Config struct { //nolint:maligned
3232
stateMutex sync.Mutex
3333
}
3434

35+
// applyUserPostRules applies user-defined post firewall rules
36+
func (c *Config) applyUserPostRules(ctx context.Context) error {
37+
const remove = false
38+
return c.runUserPostRules(ctx, c.customRulesPath, remove)
39+
}
40+
3541
// NewConfig creates a new Config instance and returns an error
3642
// if no iptables implementation is available.
3743
func NewConfig(ctx context.Context, logger Logger,

internal/firewall/iptables.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,15 @@ func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove b
264264
} else if err != nil {
265265
return err
266266
}
267+
268+
// Log when post-rules are being applied
269+
if !remove {
270+
c.logger.Info("applying user-defined post firewall rules from " + filepath)
271+
} else {
272+
c.logger.Info("removing user-defined post firewall rules from " + filepath)
273+
}
274+
275+
267276
b, err := io.ReadAll(file)
268277
if err != nil {
269278
_ = file.Close()

internal/firewall/outboundsubnets.go

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,47 @@ import (
99
"github.com/qdm12/gluetun/internal/subnet"
1010
)
1111

12-
func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []netip.Prefix) (err error) {
13-
c.stateMutex.Lock()
14-
defer c.stateMutex.Unlock()
15-
16-
if !c.enabled {
17-
c.logger.Info("firewall disabled, only updating allowed subnets internal list")
18-
c.outboundSubnets = make([]netip.Prefix, len(subnets))
19-
copy(c.outboundSubnets, subnets)
20-
return nil
21-
}
12+
func (c *Config) SetOutboundSubnets(ctx context.Context, outboundSubnets []netip.Prefix) (err error) {
13+
c.stateMutex.Lock()
14+
defer c.stateMutex.Unlock()
2215

23-
c.logger.Info("setting allowed subnets...")
16+
if !c.enabled {
17+
c.outboundSubnets = outboundSubnets
18+
return nil
19+
}
2420

25-
subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange(c.outboundSubnets, subnets)
26-
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
27-
return nil
28-
}
21+
// Remove previous outbound subnet rules
22+
for _, subnet := range c.outboundSubnets {
23+
subnetIsIPv6 := subnet.Addr().Is6()
24+
for _, defaultRoute := range c.defaultRoutes {
25+
defaultRouteIsIPv6 := defaultRoute.Family == netlink.FamilyV6
26+
ipFamilyMatch := subnetIsIPv6 == defaultRouteIsIPv6
27+
if !ipFamilyMatch {
28+
continue
29+
}
2930

30-
c.removeOutboundSubnets(ctx, subnetsToRemove)
31-
if err := c.addOutboundSubnets(ctx, subnetsToAdd); err != nil {
32-
return fmt.Errorf("setting allowed outbound subnets: %w", err)
33-
}
31+
const remove = true
32+
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
33+
defaultRoute.AssignedIP, subnet, remove)
34+
if err != nil {
35+
return err
36+
}
37+
}
38+
}
3439

35-
return nil
40+
c.outboundSubnets = outboundSubnets
41+
42+
// Add new outbound subnet rules
43+
if err = c.allowOutboundSubnets(ctx); err != nil {
44+
return err
45+
}
46+
47+
// Re-apply user post-rules after subnet changes
48+
if err = c.applyUserPostRules(ctx); err != nil {
49+
return fmt.Errorf("re-applying user post-rules after outbound subnet change: %w", err)
50+
}
51+
52+
return nil
3653
}
3754

3855
func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []netip.Prefix) {

internal/firewall/ports.go

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,42 +7,36 @@ import (
77
)
88

99
func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) {
10-
c.stateMutex.Lock()
11-
defer c.stateMutex.Unlock()
12-
13-
if port == 0 {
14-
return nil
15-
}
16-
17-
if !c.enabled {
18-
c.logger.Info("firewall disabled, only updating allowed ports internal state")
19-
existingInterfaces, ok := c.allowedInputPorts[port]
20-
if !ok {
21-
existingInterfaces = make(map[string]struct{})
22-
}
23-
existingInterfaces[intf] = struct{}{}
24-
c.allowedInputPorts[port] = existingInterfaces
25-
return nil
26-
}
27-
28-
netInterfaces, has := c.allowedInputPorts[port]
29-
if !has {
30-
netInterfaces = make(map[string]struct{})
31-
} else if _, exists := netInterfaces[intf]; exists {
32-
return nil
33-
}
34-
35-
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
36-
37-
const remove = false
38-
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
39-
return fmt.Errorf("allowing input to port %d through interface %s: %w",
40-
port, intf, err)
41-
}
42-
netInterfaces[intf] = struct{}{}
43-
c.allowedInputPorts[port] = netInterfaces
44-
45-
return nil
10+
c.stateMutex.Lock()
11+
defer c.stateMutex.Unlock()
12+
13+
interfaceSet, ok := c.allowedInputPorts[port]
14+
if !ok {
15+
interfaceSet = make(map[string]struct{})
16+
c.allowedInputPorts[port] = interfaceSet
17+
}
18+
19+
_, alreadySet := interfaceSet[intf]
20+
if alreadySet {
21+
return nil
22+
}
23+
24+
if c.enabled {
25+
const remove = false
26+
err = c.acceptInputToPort(ctx, intf, port, remove)
27+
if err != nil {
28+
return fmt.Errorf("accepting input port %d on interface %s: %w",
29+
port, intf, err)
30+
}
31+
32+
// ADD THIS: Re-apply user post-rules after port changes
33+
if err = c.applyUserPostRules(ctx); err != nil {
34+
return fmt.Errorf("re-applying user post-rules after port change: %w", err)
35+
}
36+
}
37+
38+
interfaceSet[intf] = struct{}{}
39+
return nil
4640
}
4741

4842
func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error) {

internal/firewall/vpn.go

Lines changed: 41 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,54 +7,45 @@ import (
77
"github.com/qdm12/gluetun/internal/models"
88
)
99

10-
func (c *Config) SetVPNConnection(ctx context.Context,
11-
connection models.Connection, vpnIntf string,
12-
) (err error) {
13-
c.stateMutex.Lock()
14-
defer c.stateMutex.Unlock()
15-
16-
if !c.enabled {
17-
c.logger.Info("firewall disabled, only updating internal VPN connection")
18-
c.vpnConnection = connection
19-
return nil
20-
}
21-
22-
c.logger.Info("allowing VPN connection...")
23-
24-
if c.vpnConnection.Equal(connection) {
25-
return nil
26-
}
27-
28-
remove := true
29-
if c.vpnConnection.IP.IsValid() {
30-
for _, defaultRoute := range c.defaultRoutes {
31-
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
32-
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
33-
}
34-
}
35-
}
36-
c.vpnConnection = models.Connection{}
37-
38-
if c.vpnIntf != "" {
39-
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
40-
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
41-
}
42-
}
43-
c.vpnIntf = ""
44-
45-
remove = false
46-
47-
for _, defaultRoute := range c.defaultRoutes {
48-
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
49-
return fmt.Errorf("allowing output traffic through VPN connection: %w", err)
50-
}
51-
}
52-
c.vpnConnection = connection
53-
54-
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
55-
return fmt.Errorf("accepting output traffic through interface %s: %w", vpnIntf, err)
56-
}
57-
c.vpnIntf = vpnIntf
58-
59-
return nil
10+
func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connection, intf string) (err error) {
11+
c.stateMutex.Lock()
12+
defer c.stateMutex.Unlock()
13+
14+
if !c.enabled {
15+
c.vpnConnection = connection
16+
c.vpnIntf = intf
17+
return nil
18+
}
19+
20+
// Remove previous VPN rules
21+
if c.vpnConnection.IP.IsValid() {
22+
const remove = true
23+
interfacesSeen := make(map[string]struct{}, len(c.defaultRoutes))
24+
for _, defaultRoute := range c.defaultRoutes {
25+
_, seen := interfacesSeen[defaultRoute.NetInterface]
26+
if seen {
27+
continue
28+
}
29+
interfacesSeen[defaultRoute.NetInterface] = struct{}{}
30+
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
31+
if err != nil {
32+
return fmt.Errorf("removing output traffic through VPN: %w", err)
33+
}
34+
}
35+
}
36+
37+
c.vpnConnection = connection
38+
c.vpnIntf = intf
39+
40+
// Add new VPN rules
41+
if err = c.allowVPNIP(ctx); err != nil {
42+
return err
43+
}
44+
45+
// Re-apply user post-rules after VPN changes
46+
if err = c.applyUserPostRules(ctx); err != nil {
47+
return fmt.Errorf("re-applying user post-rules after VPN change: %w", err)
48+
}
49+
50+
return nil
6051
}

0 commit comments

Comments
 (0)