Skip to content

Commit ae7744a

Browse files
authored
fix: dns over tcp bypasses domain filtering (#175)
- Fixing a bug where it was possible to bypass the domain filtering when using DNS over TCP. - Also showing the untrusted DNS server in annotations
1 parent 72cb2a3 commit ae7744a

File tree

9 files changed

+156
-65
lines changed

9 files changed

+156
-65
lines changed

action/dist/post.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19914,7 +19914,7 @@ async function getOutboundConnections() {
1991419914
);
1991519915
console.log("Agent ready timestamp: ", agentReadyTimestamp);
1991619916
const tetragonLogFile = await import_promises2.default.open(TETRAGON_EVENTS_LOG_PATH);
19917-
const functionsToTrack = ["tcp_connect"];
19917+
const functionsToTrack = ["tcp_connect", "udp_sendmsg"];
1991819918
for await (const line of tetragonLogFile.readLines()) {
1991919919
const processEntry = JSON.parse(line.trimEnd())?.process_kprobe;
1992019920
if (processEntry?.["policy_name"] !== "connect") {

action/dist/post.js.map

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

action/src/post.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ async function getOutboundConnections(): Promise<TetragonLog[]> {
7878

7979
const tetragonLogFile = await fs.open(TETRAGON_EVENTS_LOG_PATH);
8080

81-
const functionsToTrack = ["tcp_connect"];
81+
const functionsToTrack = ["tcp_connect", "udp_sendmsg"];
8282

8383
for await (const line of tetragonLogFile.readLines()) {
8484
const processEntry = JSON.parse(line.trimEnd())?.process_kprobe;

action/tetragon/connect.yml

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ spec:
1515
operator: "NotDAddr"
1616
values:
1717
- 127.0.0.1
18+
- 127.0.0.53
1819
- call: "udp_sendmsg"
1920
syscall: false
2021
args:
@@ -26,22 +27,11 @@ spec:
2627
operator: "NotDAddr"
2728
values:
2829
- 127.0.0.1
30+
- 127.0.0.53
2931
- call: "udp_recvmsg"
3032
syscall: false
3133
args:
3234
- index: 0
3335
type: "sock"
3436
- index: 2
3537
type: "size_t"
36-
# - call: "tcp_close"
37-
# syscall: false
38-
# args:
39-
# - index: 0
40-
# type: "sock"
41-
# - call: "tcp_sendmsg"
42-
# syscall: false
43-
# args:
44-
# - index: 0
45-
# type: "sock"
46-
# - index: 2
47-
# type: int

agent/agent.go

Lines changed: 125 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515

1616
var (
1717
blocking = false
18-
defaultDomains = []string{"github.com", "api.github.com", "*.actions.githubusercontent.com", "results-receiver.actions.githubusercontent.com", "*.blob.core.windows.net"}
18+
defaultDomains = []string{"github.com", "api.github.com", "*.actions.githubusercontent.com", "results-receiver.actions.githubusercontent.com", "*.blob.core.windows.net", "*.githubapp.com"}
1919
defaultIps = []string{"168.63.129.16", "169.254.169.254", "127.0.0.1"}
2020
defaultDNSServers = []string{"127.0.0.53"}
2121
)
@@ -27,6 +27,7 @@ const (
2727
EGRESS_POLICY_AUDIT = "audit"
2828
DNS_POLICY_ALLOWED_DOMAINS_ONLY = "allowed-domains-only"
2929
DNS_POLICY_ANY = "any"
30+
DNS_PORT = layers.TCPPort(53)
3031
)
3132

3233
type AgentConfig struct {
@@ -63,6 +64,7 @@ func NewAgent(config AgentConfig) *Agent {
6364
netInfoProvider: config.NetInfoProvider,
6465
filesystem: config.FileSystem,
6566
}
67+
6668
agent.init(config)
6769
return agent
6870
}
@@ -153,7 +155,7 @@ func (a *Agent) loadAllowedIp(ips []string) {
153155
a.addIpToLogs("allowed", "unknown", ip)
154156
continue
155157
}
156-
fmt.Printf("Failed to parse IP: %s. Skipping.\n", ip)
158+
fmt.Printf("failed to parse ip: %s. skipping.\n", ip)
157159
}
158160
}
159161

@@ -164,13 +166,13 @@ func (a *Agent) addToFirewall(ips map[string]bool, cidr []*net.IPNet) error {
164166
for ip := range ips {
165167
err := a.firewall.AddIp(ip)
166168
if err != nil {
167-
return fmt.Errorf("Error adding %s to firewall: %v\n", ip, err)
169+
return fmt.Errorf("error adding %s to firewall: %v", ip, err)
168170
}
169171
}
170172
for _, c := range cidr {
171173
err := a.firewall.AddIp(c.String())
172174
if err != nil {
173-
return fmt.Errorf("Error adding %s to firewall: %v\n", c.String(), err)
175+
return fmt.Errorf("error adding %s to firewall: %v", c.String(), err)
174176
}
175177
}
176178
return nil
@@ -228,22 +230,19 @@ func (a *Agent) loadAllowedDNSServers() error {
228230
}
229231

230232
func getDestinationIP(packet gopacket.Packet) (string, error) {
231-
ipLayer := packet.Layer(layers.LayerTypeIPv4)
232-
if ipLayer == nil {
233-
ipLayer = packet.Layer(layers.LayerTypeIPv6)
234-
}
235-
if ipLayer == nil {
236-
return "", fmt.Errorf("Failed to get IP layer")
233+
netLayer := packet.NetworkLayer()
234+
if netLayer == nil {
235+
return "", fmt.Errorf("failed to get network layer")
237236
}
238-
ip, _ := ipLayer.(*layers.IPv4)
239-
if ip == nil {
240-
ip6, _ := ipLayer.(*layers.IPv6)
241-
if ip6 == nil {
242-
return "", fmt.Errorf("Failed to get IP layer")
243-
}
244-
return ip6.DstIP.String(), nil
237+
238+
switch v := netLayer.(type) {
239+
case *layers.IPv4:
240+
return v.DstIP.String(), nil
241+
case *layers.IPv6:
242+
return v.DstIP.String(), nil
243+
default:
244+
return "", fmt.Errorf("unknown network layer type")
245245
}
246-
return ip.DstIP.String(), nil
247246
}
248247

249248
func extractDomainFromSRV(domain string) string {
@@ -254,25 +253,18 @@ func extractDomainFromSRV(domain string) string {
254253
return re.ReplaceAllString(domain, "")
255254
}
256255

257-
func (a *Agent) processDNSQuery(packet gopacket.Packet) uint8 {
258-
dnsLayer := packet.Layer(layers.LayerTypeDNS)
259-
dns, _ := dnsLayer.(*layers.DNS)
256+
func (a *Agent) processDNSLayer(dns *layers.DNS) uint8 {
257+
if !dns.QR {
258+
return a.processDNSQuery(dns)
259+
}
260+
return a.processDNSResponse(dns)
261+
}
262+
263+
func (a *Agent) processDNSQuery(dns *layers.DNS) uint8 {
260264
for _, q := range dns.Questions {
261265
domain := string(q.Name)
262266
fmt.Printf("DNS Question: %s %s\n", q.Name, q.Type)
263267

264-
// making sure the DNS query is using a trusted DNS server
265-
destinationIP, err := getDestinationIP(packet)
266-
if err != nil {
267-
fmt.Println("Failed to get destination IP")
268-
a.addIpToLogs("blocked", domain, "unknown")
269-
return DROP_REQUEST
270-
}
271-
if !a.allowedDNSServers[destinationIP] {
272-
fmt.Printf("%s -> Blocked DNS Query. Untrusted DNS server %s\n", domain, destinationIP)
273-
a.addIpToLogs("blocked", domain, "unknown")
274-
return DROP_REQUEST
275-
}
276268
if q.Type == layers.DNSTypeSRV {
277269
originalDomain := domain
278270
domain = extractDomainFromSRV(domain)
@@ -345,11 +337,10 @@ func (a *Agent) processDNSTypeSRVResponse(domain string, answer *layers.DNSResou
345337
}
346338
}
347339

348-
func (a *Agent) processDNSResponse(packet gopacket.Packet) uint8 {
349-
dnsLayer := packet.Layer(layers.LayerTypeDNS)
350-
dns, _ := dnsLayer.(*layers.DNS)
340+
func (a *Agent) processDNSResponse(dns *layers.DNS) uint8 {
351341
domain := string(dns.Questions[0].Name)
352342
for _, answer := range dns.Answers {
343+
fmt.Printf("DNS Answer: %s %s %s\n", answer.Name, answer.Type, answer.IP)
353344
if answer.Type == layers.DNSTypeA {
354345
a.processDNSTypeAResponse(domain, &answer)
355346
} else if answer.Type == layers.DNSTypeCNAME {
@@ -365,21 +356,108 @@ func (a *Agent) processDNSResponse(packet gopacket.Packet) uint8 {
365356
return ACCEPT_REQUEST
366357
}
367358

368-
func (a *Agent) ProcessPacket(packet gopacket.Packet) uint8 {
369-
if dnsLayer := packet.Layer(layers.LayerTypeDNS); dnsLayer != nil {
359+
func (a *Agent) processDNSPacket(packet gopacket.Packet) uint8 {
360+
dnsLayer := packet.Layer(layers.LayerTypeDNS)
361+
dns, _ := dnsLayer.(*layers.DNS)
362+
for _, q := range dns.Questions {
363+
fmt.Printf("DNS Question: %s %s\n", q.Name, q.Type)
364+
}
370365

371-
dns, _ := dnsLayer.(*layers.DNS)
372-
for _, q := range dns.Questions {
373-
fmt.Printf("DNS Question: %s %s\n", q.Name, q.Type)
366+
domain := string(dns.Questions[0].Name)
367+
// if we are blocking DNS queries, intercept the DNS queries and decide whether to block or allow them
368+
if !dns.QR {
369+
// making sure the DNS query is using a trusted DNS server
370+
destinationIP, err := getDestinationIP(packet)
371+
if err != nil {
372+
fmt.Printf("Failed to get destination IP: %v\n", err)
373+
a.addIpToLogs("blocked", domain, "unknown")
374+
return DROP_REQUEST
374375
}
375-
// if we are blocking DNS queries, intercept the DNS queries and decide whether to block or allow them
376-
if a.blockDNS && !dns.QR {
377-
return a.processDNSQuery(packet)
378-
} else if dns.QR {
379-
return a.processDNSResponse(packet)
376+
if !a.allowedDNSServers[destinationIP] {
377+
fmt.Printf("%s -> Blocked DNS Query. Untrusted DNS server %s\n", domain, destinationIP)
378+
a.addIpToLogs("blocked", domain, destinationIP)
379+
return DROP_REQUEST
380380
}
381381
}
382-
return ACCEPT_REQUEST
382+
383+
// if we are not blocking DNS queries, just accept the query request
384+
if !a.blockDNS && !dns.QR {
385+
return ACCEPT_REQUEST
386+
}
387+
return a.processDNSLayer(dns)
388+
}
389+
390+
func (a *Agent) processDNSOverTCPPayload(payload []byte) uint8 {
391+
// Extract message length from first 2 bytes
392+
// - First byte shifted left 8 bits + second byte
393+
// - Creates 16-bit length prefix
394+
messageLen := int(payload[0])<<8 | int(payload[1])
395+
if messageLen == 0 || len(payload) < messageLen+2 {
396+
fmt.Println("Invalid DNS over TCP payload")
397+
return DROP_REQUEST
398+
}
399+
400+
// We attempt to decode the DNS over TCP payload
401+
// The only way we can accept the request is if the DNS query is contained within a single TCP packet payload
402+
dns := &layers.DNS{}
403+
err := dns.DecodeFromBytes(payload[2:messageLen+2], gopacket.NilDecodeFeedback)
404+
if err != nil {
405+
fmt.Println("Failed to decode DNS over TCP payload", err)
406+
return DROP_REQUEST
407+
}
408+
return a.processDNSLayer(dns)
409+
}
410+
411+
func (a *Agent) processTCPPacket(packet gopacket.Packet) uint8 {
412+
tcpLayer := packet.Layer(layers.LayerTypeTCP)
413+
tcp, _ := tcpLayer.(*layers.TCP)
414+
dstPort, srcPort, payload := tcp.DstPort, tcp.SrcPort, tcp.Payload
415+
416+
// Validate DNS server IP
417+
if dstPort == DNS_PORT {
418+
destinationIP, err := getDestinationIP(packet)
419+
if err != nil {
420+
fmt.Printf("Failed to get destination IP: %v\n", err)
421+
a.addIpToLogs("blocked", "unknown", "unknown")
422+
return DROP_REQUEST
423+
}
424+
if !a.allowedDNSServers[destinationIP] {
425+
fmt.Printf("%s -> Blocked DNS Query. Untrusted DNS server %s\n", "unknown", destinationIP)
426+
a.addIpToLogs("blocked", "unknown", destinationIP)
427+
return DROP_REQUEST
428+
}
429+
}
430+
431+
if dstPort != DNS_PORT && srcPort != DNS_PORT {
432+
fmt.Println("Warning: Destination and source port are not DNS ports. Dropping request")
433+
return DROP_REQUEST
434+
}
435+
436+
// if we are not blocking DNS queries, just accept the query request
437+
if !a.blockDNS && dstPort == DNS_PORT {
438+
return ACCEPT_REQUEST
439+
}
440+
441+
if len(payload) == 0 {
442+
// We only accept DNS over TCP packets with no payload since they are only used for initiating a connection
443+
return ACCEPT_REQUEST
444+
}
445+
446+
// Now we have a payload in the TCP packet, we need to make sure it is a valid DNS over TCP payload and the DNS query is for a known domain. We don't want to exfiltrate data over DNS over TCP
447+
return a.processDNSOverTCPPayload(payload)
448+
449+
}
450+
451+
func (a *Agent) ProcessPacket(packet gopacket.Packet) uint8 {
452+
if dnsLayer := packet.Layer(layers.LayerTypeDNS); dnsLayer != nil {
453+
return a.processDNSPacket(packet)
454+
}
455+
// check dns over tcp
456+
if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil {
457+
return a.processTCPPacket(packet)
458+
}
459+
fmt.Println("Warning: Packet is not DNS or TCP. Dropping request, this shouldn't be happening.")
460+
return DROP_REQUEST
383461
}
384462

385463
func (a *Agent) disableSudo() error {

agent/agent.sha256

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
9dc0fa0203ab625fa501ac1728aef69a5424267a7839c5fee9bf0d85e459ac59 agent
1+
72fdea56ef365e362fd5ae69c1d5866fc4636db4e8e4aa10855d158d3f5e439f agent

test/block.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ grep --quiet 'Blocked DNS request to www.bing.com from unknown process' $POST_WA
3838
grep --quiet 'Blocked request to 93.184.215.14:443 from processs `/usr/bin/curl https://93.184.215.14 --output /dev/null' $POST_WARNINGS_FILEPATH
3939
grep --quiet 'Blocked DNS request to registry-1.docker.io from unknown process' $POST_WARNINGS_FILEPATH
4040
grep --quiet 'Blocked DNS request to www.wikipedia.org from unknown process' $POST_WARNINGS_FILEPATH
41-
grep --quiet 'Blocked DNS request to www.google.com from unknown process' $POST_WARNINGS_FILEPATH
41+
grep --quiet 'Blocked DNS request to tcp.example.com from unknown process' $POST_WARNINGS_FILEPATH
42+
grep --quiet 'Blocked request to www.google.com (8.8.8.8:53) from process `/usr/bin/dig @8.8.8.8 www.google.com`' $POST_WARNINGS_FILEPATH
43+
grep --quiet 'Blocked request to www.google.com (8.8.8.8:53) from process `/usr/bin/dig @8.8.8.8 www.google.com +tcp`' $POST_WARNINGS_FILEPATH
4244

4345
echo "Tests passed successfully"

test/input.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
process.env["INPUT_EGRESS-POLICY"] = "block";
22
process.env["INPUT_DNS-POLICY"] = "allowed-domains-only";
33
process.env["INPUT__LOG-DIRECTORY"] = "/tmp/gha-agent/logs";
4+
process.env["INPUT_ENABLE-SUDO"] = "true";
45

56
process.env["INPUT_ALLOWED-IPS"] = `
67
10.0.0.0/24

test/make_dns_requests.sh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,21 @@ if timeout 5 dig example.com; then
55
exit 1
66
fi
77

8+
if timeout 5 dig tcp.example.com +tcp; then
9+
echo 'Expected 'dig tcp.example.com +tcp' to fail, but it succeeded'
10+
exit 1
11+
fi
12+
13+
if ! timeout 5 dig www.google.com; then
14+
echo 'Expected 'dig www.google.com' to succeed, but it failed'
15+
exit 1
16+
fi
17+
18+
if ! timeout 5 dig www.google.com +tcp; then
19+
echo 'Expected 'dig www.google.com +tcp' to succeed, but it failed'
20+
exit 1
21+
fi
22+
823
if timeout 5 dig www.wikipedia.org; then
924
echo 'Expected 'dig www.wikipedia.org' to fail, but it succeeded'
1025
exit 1
@@ -14,3 +29,8 @@ if timeout 5 dig @8.8.8.8 www.google.com; then
1429
echo 'Expected 'dig @8.8.8.8 www.google.com' to fail, but it succeeded'
1530
exit 1
1631
fi
32+
33+
if timeout 5 dig @8.8.8.8 www.google.com +tcp; then
34+
echo 'Expected 'dig @8.8.8.8 www.google.com +tcp' to fail, but it succeeded'
35+
exit 1
36+
fi

0 commit comments

Comments
 (0)