-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
143 lines (110 loc) · 4.28 KB
/
main.py
File metadata and controls
143 lines (110 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import ipaddress
import socket
import yaml
BUFFER_SIZE = 512  # receive buffer size, in bytes, for incoming DNS datagrams
HEADER_SIZE = 12  # length in bytes of the fixed DNS message header
CLIENT_TIMEOUT = 5  # seconds the listening socket waits before recvfrom times out
SERVER_TIMEOUT = 15  # seconds to wait for the upstream DNS server's reply

# Network that is currently allowed to use the proxy. It is set by
# is_blocked() when a client queries the auth domain, and stays None until then.
# It holds an ipaddress network object, not a string.
allowed_subnet: ipaddress.IPv4Network | ipaddress.IPv6Network | None = None

# Configuration is loaded once, at import time, from "config.yaml" relative to
# the current working directory. The encoding is explicit so that parsing does
# not depend on the platform's default locale.
with open("config.yaml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)
def forward_request(data: bytes, buffer_size: int = 65535) -> bytes | None:
    """Relay a raw DNS query to the configured upstream server and return its reply.

    A fresh UDP socket is opened for every call and is always closed, even
    when sending or receiving fails.

    Args:
        data: The raw DNS query datagram received from the client.
        buffer_size: Largest upstream reply accepted, in bytes. The default is
            the maximum UDP payload, so EDNS0 replies longer than 512 bytes are
            returned whole instead of being silently truncated.

    Returns:
        The raw upstream response. Returns None if the upstream did not answer
        within SERVER_TIMEOUT seconds, or if any socket error occurred (for
        example, an unresolvable upstream address).
    """
    upstream = (config["dns_server"]["address"], config["dns_server"]["port"])
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as dns_socket:
        dns_socket.settimeout(SERVER_TIMEOUT)
        try:
            # On a UDP socket, connect() only fixes the peer. The kernel then
            # discards datagrams from any other address, so only the upstream
            # server's reply can be returned to the client.
            dns_socket.connect(upstream)
            dns_socket.send(data)
            response = dns_socket.recv(buffer_size)
        except socket.timeout:
            # socket.timeout subclasses OSError, so it must be handled first.
            print("Forwarding request timeout")
            return None
        except OSError as err:
            print(f"Forwarding request failed: {err}")
            return None
    return response
def _normalize_domain(name: str) -> str:
    """Return *name* lower-cased and without trailing root dots."""
    return name.lower().rstrip(".")


def _domain_matches(domain: str, pattern: str) -> bool:
    """Return True if the normalized *domain* matches block-list *pattern*.

    A pattern of the form "*.example.com" matches "example.com" itself and
    any name ending in ".example.com". It does not match names that merely end
    with the same characters, such as "badexample.com". Any other pattern must
    match exactly.
    """
    pattern = _normalize_domain(pattern)
    if pattern.startswith("*."):
        suffix = pattern[2:]
        return domain == suffix or domain.endswith("." + suffix)
    return domain == pattern


def is_blocked(client_address, domain: str) -> bool:
    """Decide whether a query for *domain* from *client_address* gets a block response.

    Querying config["auth_domain"] is itself the authentication step. It
    replaces the module-level allowed_subnet with the network formed from the
    client's IP and config["allowed_subnet"] (the part after the "/"), and that
    query is answered with a block response. Any other query is blocked unless
    all of the following hold:

    - a subnet has been recorded;
    - the client lies inside that subnet;
    - the domain matches no entry in config["blocked_domains"].

    Names are compared case-insensitively, as DNS names are, and trailing dots
    are ignored.

    Args:
        client_address: Address tuple as returned by recvfrom. Only element 0,
            the client's IP string, is used.
        domain: The queried name, with dot-separated labels.

    Returns:
        True if the query should be answered with a block response.
    """
    global allowed_subnet
    domain = _normalize_domain(domain)
    # A query for the auth domain authorizes the caller's network.
    if domain == _normalize_domain(config["auth_domain"]):
        print("Updating allowed address")
        allowed_subnet = ipaddress.ip_network(f"{client_address[0]}/{config['allowed_subnet']}", strict=False)
        return True
    # Nobody has authenticated yet.
    if allowed_subnet is None:
        print("No address allowed. Call auth domain first")
        return True
    # Only the most recently authenticated network may resolve names.
    if ipaddress.ip_address(client_address[0]) not in allowed_subnet:
        print("Address not allowed")
        return True
    # Apply the block list.
    for blocked_domain in config["blocked_domains"]:
        if _domain_matches(domain, blocked_domain):
            print("Domain blocked")
            return True
    return False
def extract_domain_name(data: bytes) -> str | None:
# Skip the header
offset = HEADER_SIZE
domain_name = []
# Read domain name from the query section
while True:
length = data[offset]
if length == 0:
break
offset += 1
try:
domain_name.append(data[offset: offset + length].decode("utf-8"))
except UnicodeDecodeError:
return None
offset += length
return ".".join(domain_name)
def generate_block_response(data: bytes) -> bytes:
# Construct a DNS response indicating the query is blocked
transaction_id = data[:2]
flags = b"\x81\x83" # Standard query response, refused
qdcount = b"\x00\x01" # One question
ancount = b"\x00\x00" # Zero answers
nscount = b"\x00\x00" # Zero authority records
arcount = b"\x00\x00" # Zero additional records
# Return the header and the original query
return transaction_id + flags + qdcount + ancount + nscount + arcount + data[12:]
def _handle_request(data: bytes, client_address) -> bytes | None:
    """Return the reply for one client datagram, or None if no reply should be sent."""
    if len(data) <= HEADER_SIZE:
        print("Request too small")
        return None
    try:
        domain = extract_domain_name(data)
    except IndexError:
        # A truncated name can make the parser index past the end of the datagram.
        domain = None
    if domain is None:
        print("Malformed domain name")
        return None
    print(f"Queried domain: {domain}")
    if is_blocked(client_address, domain):
        print("Blocking request")
        return generate_block_response(data)
    print("Forwarding request")
    try:
        return forward_request(data)
    except OSError as err:
        print(f"Forwarding request failed: {err}")
        return None


def main():
    """Serve DNS queries forever on the configured listen address.

    Each datagram is parsed for its queried name. It is then either answered
    locally with a block response or relayed to the upstream server, and the
    reply is sent back to the client.

    An error on a single request is reported and skipped, so one bad packet or
    unreachable client cannot stop the proxy. Requests are handled one at a
    time, so a slow upstream delays every other client. NOTE(review): the delay
    is up to SERVER_TIMEOUT seconds per forwarded query.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as server_socket:
        # The timeout only wakes the loop periodically; it is not an error.
        server_socket.settimeout(CLIENT_TIMEOUT)
        server_socket.bind((config["listen"]["address"], config["listen"]["port"]))
        print(f"DNS Proxy listening on {config['listen']['address']}:{config['listen']['port']}")
        while True:
            try:
                data, client_address = server_socket.recvfrom(BUFFER_SIZE)
            except socket.timeout:
                # socket.timeout subclasses OSError, so it must be handled first.
                print("Timeout")
                continue
            except OSError as err:
                # Transient socket errors must not terminate the server loop.
                print(f"Receive failed: {err}")
                continue
            print(f"<<< Received request from {client_address}")
            response = _handle_request(data, client_address)
            if response is None:
                continue
            try:
                server_socket.sendto(response, client_address)
            except OSError as err:
                print(f"Sending response failed: {err}")
                continue
            print(f">>> Sent response to {client_address}")
# Start the proxy only when this file is executed as a script. Importing it,
# for example to reuse the parsing helpers, no longer blocks in the server loop.
if __name__ == "__main__":
    main()