@@ -15,7 +15,7 @@ import (
15
15
16
16
var (
17
17
blocking = false
18
- defaultDomains = []string {"github.com" , "api.github.com" , "*.actions.githubusercontent.com" , "results-receiver.actions.githubusercontent.com" , "*.blob.core.windows.net" }
18
+ defaultDomains = []string {"github.com" , "api.github.com" , "*.actions.githubusercontent.com" , "results-receiver.actions.githubusercontent.com" , "*.blob.core.windows.net" , "*.githubapp.com" }
19
19
defaultIps = []string {"168.63.129.16" , "169.254.169.254" , "127.0.0.1" }
20
20
defaultDNSServers = []string {"127.0.0.53" }
21
21
)
@@ -27,6 +27,7 @@ const (
27
27
EGRESS_POLICY_AUDIT = "audit"
28
28
DNS_POLICY_ALLOWED_DOMAINS_ONLY = "allowed-domains-only"
29
29
DNS_POLICY_ANY = "any"
30
+ DNS_PORT = layers .TCPPort (53 )
30
31
)
31
32
32
33
type AgentConfig struct {
@@ -63,6 +64,7 @@ func NewAgent(config AgentConfig) *Agent {
63
64
netInfoProvider : config .NetInfoProvider ,
64
65
filesystem : config .FileSystem ,
65
66
}
67
+
66
68
agent .init (config )
67
69
return agent
68
70
}
@@ -153,7 +155,7 @@ func (a *Agent) loadAllowedIp(ips []string) {
153
155
a .addIpToLogs ("allowed" , "unknown" , ip )
154
156
continue
155
157
}
156
- fmt .Printf ("Failed to parse IP : %s. Skipping .\n " , ip )
158
+ fmt .Printf ("failed to parse ip : %s. skipping .\n " , ip )
157
159
}
158
160
}
159
161
@@ -164,13 +166,13 @@ func (a *Agent) addToFirewall(ips map[string]bool, cidr []*net.IPNet) error {
164
166
for ip := range ips {
165
167
err := a .firewall .AddIp (ip )
166
168
if err != nil {
167
- return fmt .Errorf ("Error adding %s to firewall: %v\n " , ip , err )
169
+ return fmt .Errorf ("error adding %s to firewall: %v" , ip , err )
168
170
}
169
171
}
170
172
for _ , c := range cidr {
171
173
err := a .firewall .AddIp (c .String ())
172
174
if err != nil {
173
- return fmt .Errorf ("Error adding %s to firewall: %v\n " , c .String (), err )
175
+ return fmt .Errorf ("error adding %s to firewall: %v" , c .String (), err )
174
176
}
175
177
}
176
178
return nil
@@ -228,22 +230,19 @@ func (a *Agent) loadAllowedDNSServers() error {
228
230
}
229
231
230
232
func getDestinationIP (packet gopacket.Packet ) (string , error ) {
231
- ipLayer := packet .Layer (layers .LayerTypeIPv4 )
232
- if ipLayer == nil {
233
- ipLayer = packet .Layer (layers .LayerTypeIPv6 )
234
- }
235
- if ipLayer == nil {
236
- return "" , fmt .Errorf ("Failed to get IP layer" )
233
+ netLayer := packet .NetworkLayer ()
234
+ if netLayer == nil {
235
+ return "" , fmt .Errorf ("failed to get network layer" )
237
236
}
238
- ip , _ := ipLayer .(* layers.IPv4 )
239
- if ip == nil {
240
- ip6 , _ := ipLayer .(* layers.IPv6 )
241
- if ip6 == nil {
242
- return "" , fmt .Errorf ("Failed to get IP layer" )
243
- }
244
- return ip6 .DstIP .String (), nil
237
+
238
+ switch v := netLayer .(type ) {
239
+ case * layers.IPv4 :
240
+ return v .DstIP .String (), nil
241
+ case * layers.IPv6 :
242
+ return v .DstIP .String (), nil
243
+ default :
244
+ return "" , fmt .Errorf ("unknown network layer type" )
245
245
}
246
- return ip .DstIP .String (), nil
247
246
}
248
247
249
248
func extractDomainFromSRV (domain string ) string {
@@ -254,25 +253,18 @@ func extractDomainFromSRV(domain string) string {
254
253
return re .ReplaceAllString (domain , "" )
255
254
}
256
255
257
- func (a * Agent ) processDNSQuery (packet gopacket.Packet ) uint8 {
258
- dnsLayer := packet .Layer (layers .LayerTypeDNS )
259
- dns , _ := dnsLayer .(* layers.DNS )
256
+ func (a * Agent ) processDNSLayer (dns * layers.DNS ) uint8 {
257
+ if ! dns .QR {
258
+ return a .processDNSQuery (dns )
259
+ }
260
+ return a .processDNSResponse (dns )
261
+ }
262
+
263
+ func (a * Agent ) processDNSQuery (dns * layers.DNS ) uint8 {
260
264
for _ , q := range dns .Questions {
261
265
domain := string (q .Name )
262
266
fmt .Printf ("DNS Question: %s %s\n " , q .Name , q .Type )
263
267
264
- // making sure the DNS query is using a trusted DNS server
265
- destinationIP , err := getDestinationIP (packet )
266
- if err != nil {
267
- fmt .Println ("Failed to get destination IP" )
268
- a .addIpToLogs ("blocked" , domain , "unknown" )
269
- return DROP_REQUEST
270
- }
271
- if ! a .allowedDNSServers [destinationIP ] {
272
- fmt .Printf ("%s -> Blocked DNS Query. Untrusted DNS server %s\n " , domain , destinationIP )
273
- a .addIpToLogs ("blocked" , domain , "unknown" )
274
- return DROP_REQUEST
275
- }
276
268
if q .Type == layers .DNSTypeSRV {
277
269
originalDomain := domain
278
270
domain = extractDomainFromSRV (domain )
@@ -345,11 +337,10 @@ func (a *Agent) processDNSTypeSRVResponse(domain string, answer *layers.DNSResou
345
337
}
346
338
}
347
339
348
- func (a * Agent ) processDNSResponse (packet gopacket.Packet ) uint8 {
349
- dnsLayer := packet .Layer (layers .LayerTypeDNS )
350
- dns , _ := dnsLayer .(* layers.DNS )
340
+ func (a * Agent ) processDNSResponse (dns * layers.DNS ) uint8 {
351
341
domain := string (dns .Questions [0 ].Name )
352
342
for _ , answer := range dns .Answers {
343
+ fmt .Printf ("DNS Answer: %s %s %s\n " , answer .Name , answer .Type , answer .IP )
353
344
if answer .Type == layers .DNSTypeA {
354
345
a .processDNSTypeAResponse (domain , & answer )
355
346
} else if answer .Type == layers .DNSTypeCNAME {
@@ -365,21 +356,108 @@ func (a *Agent) processDNSResponse(packet gopacket.Packet) uint8 {
365
356
return ACCEPT_REQUEST
366
357
}
367
358
368
- func (a * Agent ) ProcessPacket (packet gopacket.Packet ) uint8 {
369
- if dnsLayer := packet .Layer (layers .LayerTypeDNS ); dnsLayer != nil {
359
+ func (a * Agent ) processDNSPacket (packet gopacket.Packet ) uint8 {
360
+ dnsLayer := packet .Layer (layers .LayerTypeDNS )
361
+ dns , _ := dnsLayer .(* layers.DNS )
362
+ for _ , q := range dns .Questions {
363
+ fmt .Printf ("DNS Question: %s %s\n " , q .Name , q .Type )
364
+ }
370
365
371
- dns , _ := dnsLayer .(* layers.DNS )
372
- for _ , q := range dns .Questions {
373
- fmt .Printf ("DNS Question: %s %s\n " , q .Name , q .Type )
366
+ domain := string (dns .Questions [0 ].Name )
367
+ // if we are blocking DNS queries, intercept the DNS queries and decide whether to block or allow them
368
+ if ! dns .QR {
369
+ // making sure the DNS query is using a trusted DNS server
370
+ destinationIP , err := getDestinationIP (packet )
371
+ if err != nil {
372
+ fmt .Printf ("Failed to get destination IP: %v\n " , err )
373
+ a .addIpToLogs ("blocked" , domain , "unknown" )
374
+ return DROP_REQUEST
374
375
}
375
- // if we are blocking DNS queries, intercept the DNS queries and decide whether to block or allow them
376
- if a .blockDNS && ! dns .QR {
377
- return a .processDNSQuery (packet )
378
- } else if dns .QR {
379
- return a .processDNSResponse (packet )
376
+ if ! a .allowedDNSServers [destinationIP ] {
377
+ fmt .Printf ("%s -> Blocked DNS Query. Untrusted DNS server %s\n " , domain , destinationIP )
378
+ a .addIpToLogs ("blocked" , domain , destinationIP )
379
+ return DROP_REQUEST
380
380
}
381
381
}
382
- return ACCEPT_REQUEST
382
+
383
+ // if we are not blocking DNS queries, just accept the query request
384
+ if ! a .blockDNS && ! dns .QR {
385
+ return ACCEPT_REQUEST
386
+ }
387
+ return a .processDNSLayer (dns )
388
+ }
389
+
390
+ func (a * Agent ) processDNSOverTCPPayload (payload []byte ) uint8 {
391
+ // Extract message length from first 2 bytes
392
+ // - First byte shifted left 8 bits + second byte
393
+ // - Creates 16-bit length prefix
394
+ messageLen := int (payload [0 ])<< 8 | int (payload [1 ])
395
+ if messageLen == 0 || len (payload ) < messageLen + 2 {
396
+ fmt .Println ("Invalid DNS over TCP payload" )
397
+ return DROP_REQUEST
398
+ }
399
+
400
+ // We attempt to decode the DNS over TCP payload
401
+ // The only way we can accept the request is if the DNS query is contained within a single TCP packet payload
402
+ dns := & layers.DNS {}
403
+ err := dns .DecodeFromBytes (payload [2 :messageLen + 2 ], gopacket .NilDecodeFeedback )
404
+ if err != nil {
405
+ fmt .Println ("Failed to decode DNS over TCP payload" , err )
406
+ return DROP_REQUEST
407
+ }
408
+ return a .processDNSLayer (dns )
409
+ }
410
+
411
+ func (a * Agent ) processTCPPacket (packet gopacket.Packet ) uint8 {
412
+ tcpLayer := packet .Layer (layers .LayerTypeTCP )
413
+ tcp , _ := tcpLayer .(* layers.TCP )
414
+ dstPort , srcPort , payload := tcp .DstPort , tcp .SrcPort , tcp .Payload
415
+
416
+ // Validate DNS server IP
417
+ if dstPort == DNS_PORT {
418
+ destinationIP , err := getDestinationIP (packet )
419
+ if err != nil {
420
+ fmt .Printf ("Failed to get destination IP: %v\n " , err )
421
+ a .addIpToLogs ("blocked" , "unknown" , "unknown" )
422
+ return DROP_REQUEST
423
+ }
424
+ if ! a .allowedDNSServers [destinationIP ] {
425
+ fmt .Printf ("%s -> Blocked DNS Query. Untrusted DNS server %s\n " , "unknown" , destinationIP )
426
+ a .addIpToLogs ("blocked" , "unknown" , destinationIP )
427
+ return DROP_REQUEST
428
+ }
429
+ }
430
+
431
+ if dstPort != DNS_PORT && srcPort != DNS_PORT {
432
+ fmt .Println ("Warning: Destination and source port are not DNS ports. Dropping request" )
433
+ return DROP_REQUEST
434
+ }
435
+
436
+ // if we are not blocking DNS queries, just accept the query request
437
+ if ! a .blockDNS && dstPort == DNS_PORT {
438
+ return ACCEPT_REQUEST
439
+ }
440
+
441
+ if len (payload ) == 0 {
442
+ // We only accept DNS over TCP packets with no payload since they are only used for initiating a connection
443
+ return ACCEPT_REQUEST
444
+ }
445
+
446
+ // Now we have a payload in the TCP packet, we need to make sure it is a valid DNS over TCP payload and the DNS query is for a known domain. We don't want to exfiltrate data over DNS over TCP
447
+ return a .processDNSOverTCPPayload (payload )
448
+
449
+ }
450
+
451
+ func (a * Agent ) ProcessPacket (packet gopacket.Packet ) uint8 {
452
+ if dnsLayer := packet .Layer (layers .LayerTypeDNS ); dnsLayer != nil {
453
+ return a .processDNSPacket (packet )
454
+ }
455
+ // check dns over tcp
456
+ if tcpLayer := packet .Layer (layers .LayerTypeTCP ); tcpLayer != nil {
457
+ return a .processTCPPacket (packet )
458
+ }
459
+ fmt .Println ("Warning: Packet is not DNS or TCP. Dropping request, this shouldn't be happening." )
460
+ return DROP_REQUEST
383
461
}
384
462
385
463
func (a * Agent ) disableSudo () error {
0 commit comments