-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathclient.py
More file actions
357 lines (289 loc) · 12.9 KB
/
client.py
File metadata and controls
357 lines (289 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
"""
S2 protocol client for handling pairing and secure connections.
"""
import abc
import json
import uuid
import datetime
import logging
from typing import Dict, Optional, Tuple, Union, List, Any
from jwskate import Jwk
from pydantic import BaseModel
from s2python.generated.gen_s2_pairing import (
ConnectionDetails,
ConnectionRequest,
PairingRequest,
PairingResponse,
PairingToken,
S2NodeDescription,
Protocols,
)
# Timeout for outgoing HTTP requests. Not referenced in this module itself;
# presumably seconds, for use by concrete _make_https_request implementations
# (TODO confirm).
REQUEST_TIMEOUT = 10
# Backward-compatible alias for the original, misspelled constant name.
REQTEST_TIMEOUT = REQUEST_TIMEOUT
# Pairing timeout of five minutes. Not referenced in this module itself.
PAIRING_TIMEOUT = datetime.timedelta(minutes=5)
# Key-management algorithm name; presumably the JWE "alg" used for the
# connection challenge (TODO confirm against implementations).
KEY_ALGORITHM = "RSA-OAEP-256"
# Set up logging
# NOTE(review): calling logging.basicConfig at import time configures the root
# logger for the whole process; libraries usually leave that to the
# application. Kept to preserve existing behavior.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("S2AbstractClient")
class PairingDetails(BaseModel):
    """Contains all details from the pairing process.

    A pydantic model bundling the server's pairing response, the connection
    details, and (optionally) the decrypted challenge string. This class is
    not instantiated in this module; concrete clients are presumably expected
    to populate it (TODO confirm).
    """
    # Response returned by the server to the pairing request.
    pairing_response: PairingResponse
    # Connection details for the paired server (the type returned by
    # S2AbstractClient.request_connection).
    connection_details: ConnectionDetails
    # Decrypted challenge text, if available; defaults to None.
    decrypted_challenge_str: Optional[str] = None
class S2AbstractClient(abc.ABC):
    """Abstract client for handling S2 protocol pairing and connections.

    Client handles:
    - HTTP client with TLS
    - Storage of connection request URI
    - Storage of public/private key pairs
    - Challenge solving

    This class serves as an interface that developers can extend to implement
    S2 protocol functionality with their preferred technology stack.
    Concrete implementations must override the abstract methods marked
    with @abc.abstractmethod.
    """

    # pylint: disable=too-many-instance-attributes
    # pylint: disable=too-many-arguments

    def __init__(
        self,
        pairing_uri: Optional[str] = None,
        token: Optional[PairingToken] = None,
        node_description: Optional[S2NodeDescription] = None,
        verify_certificate: Union[bool, str] = False,
        client_node_id: Optional[uuid.UUID] = None,
        supported_protocols: Optional[List[Protocols]] = None,
    ) -> None:
        """Initialize the client with configuration parameters.

        Args:
            pairing_uri: URI for the pairing request.
            token: Pairing token for authentication.
            node_description: S2 node description sent with the pairing request.
            verify_certificate: Whether to verify SSL certificates (or a path to
                a CA cert). Stored for use by concrete implementations.
                Defaults to False, i.e. no verification.
                NOTE(review): consider enabling verification in production.
            client_node_id: Client node UUID (a random uuid4 is generated if
                not provided).
            supported_protocols: List of supported protocols (defaults to
                [Protocols.WebSocketSecure]).
        """
        # Connection and authentication info
        self.pairing_uri = pairing_uri
        self.token = token
        self.node_description = node_description
        self.verify_certificate = verify_certificate
        self.client_node_id = client_node_id if client_node_id else uuid.uuid4()
        self.supported_protocols = supported_protocols or [Protocols.WebSocketSecure]

        # Internal state; populated by request_pairing/request_connection and
        # by concrete implementations (e.g. store_key_pair).
        self._connection_request_uri: Optional[str] = None
        self._public_key: Optional[str] = None
        self._private_key: Optional[str] = None
        self._public_jwk: Optional[Jwk] = None
        self._private_jwk: Optional[Jwk] = None
        self._key_pair: Optional[Jwk] = None
        self._pairing_response: Optional[PairingResponse] = None
        self._connection_details: Optional[ConnectionDetails] = None
        self._pairing_details: Optional[PairingDetails] = None

    @property
    def connection_request_uri(self) -> Optional[str]:
        """Get the stored connection request URI."""
        return self._connection_request_uri

    def store_connection_request_uri(self, uri: Optional[str]) -> None:
        """Store the connection request URI.

        If the provided URI is empty, None, or doesn't contain 'requestConnection',
        it will attempt to derive it from the pairing URI by replacing
        'requestPairing' with 'requestConnection'. If neither works, the stored
        URI is reset to None.

        Args:
            uri: The connection request URI from the pairing response.
        """
        # A URI containing "requestConnection" is necessarily non-empty and not
        # whitespace-only, so no separate emptiness check is needed.
        if uri and "requestConnection" in uri:
            self._connection_request_uri = uri
        elif self.pairing_uri is not None and "requestPairing" in self.pairing_uri:
            # Fall back to constructing the URI from the pairing URI
            self._connection_request_uri = self.pairing_uri.replace(
                "requestPairing", "requestConnection"
            )
        else:
            # No valid URI could be determined
            self._connection_request_uri = None

    @abc.abstractmethod
    def generate_key_pair(self) -> Tuple[str, str]:
        """Generate a public/private key pair.

        This method should be implemented by concrete subclasses to use their
        preferred cryptographic libraries or key management systems.

        Returns:
            Tuple[str, str]: (public_key, private_key) pair as base64 encoded strings
        """

    @abc.abstractmethod
    def store_key_pair(self, public_key: str, private_key: str) -> None:
        """Store the public/private key pair.

        This method should be implemented by concrete subclasses to store keys
        according to their security requirements (e.g., secure storage, HSM, etc.).
        Implementations are expected to set ``self._public_key``, which
        request_pairing sends to the server.

        Args:
            public_key: Base64 encoded public key
            private_key: Base64 encoded private key
        """

    @abc.abstractmethod
    def _make_https_request(
        self,
        url: str,
        method: str = "GET",
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Tuple[int, str]:
        """Make an HTTPS request.

        This method should be implemented by concrete subclasses to use their
        preferred HTTP client library or framework.

        Args:
            url: Target URL
            method: HTTP method (GET, POST, etc.)
            data: Request body data
            headers: HTTP headers

        Returns:
            Tuple[int, str]: (status_code, response_text)
        """

    def request_pairing(self) -> PairingResponse:
        """Send a pairing request to the server using client configuration.

        Generates and stores a key pair first if no public key is known yet.
        On success, the response is stored and the connection request URI is
        derived from it (see store_connection_request_uri).

        Returns:
            PairingResponse: The server's response to the pairing request

        Raises:
            ValueError: If pairing_uri or token is not set, if the server does
                not answer with HTTP 200, or if the response body is not valid
                JSON / does not validate as a PairingResponse.
        """
        if not self.pairing_uri:
            raise ValueError(
                "Pairing URI not set. Set pairing_uri before calling request_pairing."
            )
        if not self.token:
            raise ValueError(
                "Pairing token not set. Set token before calling request_pairing."
            )

        # Ensure we have keys
        if not self._public_key:
            public_key, private_key = self.generate_key_pair()
            self.store_key_pair(public_key, private_key)
            if not self._public_key:
                # store_key_pair is not guaranteed to populate _public_key;
                # without this fallback the request would go out with no key.
                # Only the (non-secret) public key is cached here.
                self._public_key = public_key

        # Create pairing request
        logger.info("Creating pairing request")
        pairing_request = PairingRequest(
            token=self.token,
            publicKey=self._public_key,
            s2ClientNodeId=str(self.client_node_id),
            s2ClientNodeDescription=self.node_description,
            supportedProtocols=self.supported_protocols,
        )

        # Make pairing request
        logger.info("Making pairing request")
        status_code, response_text = self._make_https_request(
            url=self.pairing_uri,
            method="POST",
            data=pairing_request.model_dump(exclude_none=True),
            headers={"Content-Type": "application/json"},
        )
        logger.info('Pairing request response: %s %s', status_code, response_text)

        # Parse response
        if status_code != 200:
            raise ValueError(
                f"Pairing request failed with status {status_code}: {response_text}"
            )
        pairing_response = PairingResponse.model_validate(json.loads(response_text))

        # Store for later use. A missing requestConnectionUri is passed as ""
        # (not the literal string "None") so the pairing-URI fallback applies.
        self._pairing_response = pairing_response
        request_connection_uri = pairing_response.requestConnectionUri
        self.store_connection_request_uri(
            str(request_connection_uri) if request_connection_uri is not None else ""
        )
        return pairing_response

    def request_connection(self) -> ConnectionDetails:
        """Request connection details from the server.

        A relative ``connectionUri`` in the response is resolved against
        pairing_uri (see _resolve_connection_uri). The resulting details are
        stored on the client.

        Returns:
            ConnectionDetails: The connection details returned by the server

        Raises:
            ValueError: If connection request URI is not set, if the server does
                not answer with HTTP 200, or if the response body is not valid
                JSON / does not validate as ConnectionDetails.
        """
        if not self._connection_request_uri:
            raise ValueError(
                "Connection request URI not set. Call request_pairing first."
            )

        # Create connection request
        connection_request = ConnectionRequest(
            s2ClientNodeId=str(self.client_node_id),
            supportedProtocols=self.supported_protocols,
        )

        # Make a POST request to the connection request URI
        status_code, response_text = self._make_https_request(
            url=self._connection_request_uri,
            method="POST",
            data=connection_request.model_dump(exclude_none=True),
            headers={"Content-Type": "application/json"},
        )

        # Parse response
        if status_code != 200:
            raise ValueError(
                f"Connection request failed with status {status_code}: {response_text}"
            )
        connection_details = ConnectionDetails.model_validate(json.loads(response_text))
        connection_details = self._resolve_connection_uri(connection_details)

        # Store for later use
        self._connection_details = connection_details
        return connection_details

    def _resolve_connection_uri(
        self, connection_details: ConnectionDetails
    ) -> ConnectionDetails:
        """Make a relative WebSocket ``connectionUri`` absolute.

        If ``connectionUri`` is set and does not start with ``ws://`` or
        ``wss://``, it is treated as a path relative to pairing_uri. The
        pairing URI's scheme is mapped http -> ws / https -> wss, the
        'requestPairing' segment is removed, and the relative path is appended.

        Args:
            connection_details: Details as returned by the server.

        Returns:
            ConnectionDetails: A new object with the absolute URI, or the input
            unchanged if no rewrite is needed or possible (a warning is logged
            in the latter case).
        """
        connection_uri = connection_details.connectionUri
        if connection_uri is None or str(connection_uri).startswith(("ws://", "wss://")):
            return connection_details

        if not self.pairing_uri:
            # Don't modify the URI if we can't create a proper absolute URI
            logger.warning(
                'Received relative WebSocket URI but pairing_uri is not available to create absolute URL'
            )
            return connection_details

        # Convert only the scheme prefix to the WebSocket protocol; replacing
        # "http://" anywhere in the string could corrupt query parameters.
        base_uri = self.pairing_uri
        if base_uri.startswith("https://"):
            base_uri = "wss://" + base_uri[len("https://"):]
        elif base_uri.startswith("http://"):
            base_uri = "ws://" + base_uri[len("http://"):]
        ws_base = base_uri.replace("requestPairing", "").rstrip("/")

        relative_path = str(connection_uri).lstrip("/")
        full_ws_url = f"{ws_base}/{relative_path}"

        try:
            # Rebuild the model so the new URI goes through field validation.
            connection_data = connection_details.model_dump()
            connection_data["connectionUri"] = full_ws_url
            resolved = ConnectionDetails.model_validate(connection_data)
        except (ValueError, TypeError, KeyError) as e:
            logger.warning('Failed to update WebSocket URI: %s', e)
            return connection_details

        logger.info('Updated relative WebSocket URI to absolute: %s', full_ws_url)
        return resolved

    @abc.abstractmethod
    def solve_challenge(self, challenge: Optional[str] = None) -> str:
        """Solve the connection challenge using the client's private key.

        If no challenge is provided, uses the challenge from connection_details.
        The challenge is a JWE (JSON Web Encryption) encrypted to the client's
        public key; it must be decrypted with the corresponding private key
        (see KEY_ALGORITHM, RSA-OAEP-256), then encoded as a base64 string.

        Args:
            challenge: The challenge string from the server (optional)

        Returns:
            str: The solution to the challenge (base64 encoded decrypted challenge)

        Raises:
            ValueError: If no challenge is provided and connection_details is not set
            ValueError: If the private key is not available
            RuntimeError: If challenge decryption fails
        """

    @abc.abstractmethod
    def establish_secure_connection(self) -> Any:
        """Establish a secure connection to the server.

        This method should be implemented by concrete subclasses to establish
        a secure connection using the connection details and solved challenge.
        Implementations need to use WebSocket Secure.

        Returns:
            Any: A connection object or handler specific to the implementation

        Raises:
            ValueError: If connection details or solved challenge are not available
            RuntimeError: If connection establishment fails
        """

    @abc.abstractmethod
    def close_connection(self) -> None:
        """Close the connection to the server.

        This method should be implemented by concrete subclasses to properly
        close the connection established by establish_secure_connection.
        """

    @property
    def pairing_details(self) -> Optional[PairingDetails]:
        """Get the stored pairing details (None unless set by a subclass)."""
        return self._pairing_details