Skip to content

Commit 1e04534

Browse files
committed
iSuch reactive! So stream!
1 parent 860b3a6 commit 1e04534

File tree

2 files changed

+125
-159
lines changed

2 files changed

+125
-159
lines changed

neo4j/connection.py

Lines changed: 113 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535

3636
DEFAULT_PORT = 7687
37+
DEFAULT_USER_AGENT = "neo4j-python/0.0"
3738

3839
# Signature bytes for each message type
3940
INIT = b"\x01" # 0000 0001 // INIT <user_agent>
@@ -76,16 +77,20 @@ def hex2(x):
7677
return hex(x)[2:].upper()
7778

7879

79-
class ChunkWriter(object):
80-
""" Writer for chunked data.
80+
class ChunkChannel(object):
81+
""" Reader/writer for chunked data.
82+
83+
.. note:: logs at DEBUG level
8184
"""
8285

8386
max_chunk_size = 65535
8487

85-
def __init__(self):
88+
def __init__(self, sock):
89+
self.socket = sock
8690
self.raw = BytesIO()
8791
self.output_buffer = []
8892
self.output_size = 0
93+
self._recv_buffer = b""
8994

9095
def write(self, b):
9196
""" Write some bytes, splitting into chunks if necessary.
@@ -106,7 +111,7 @@ def write(self, b):
106111
self.output_size = future_size
107112
b = b""
108113

109-
def flush(self, zero_chunk=False):
114+
def flush(self, end_of_message=False):
110115
""" Flush everything written since the last chunk to the
111116
stream, followed by a zero-chunk if required.
112117
"""
@@ -115,28 +120,71 @@ def flush(self, zero_chunk=False):
115120
lines = [struct_pack(">H", self.output_size)] + output_buffer
116121
else:
117122
lines = []
118-
if zero_chunk:
123+
if end_of_message:
119124
lines.append(b"\x00\x00")
120125
if lines:
121126
self.raw.writelines(lines)
122127
self.raw.flush()
123128
del output_buffer[:]
124129
self.output_size = 0
125130

126-
def to_bytes(self):
127-
""" Extract the written data as bytes.
131+
def send(self):
132+
""" Send all queued messages to the server.
128133
"""
129-
return self.raw.getvalue()
134+
data = self.raw.getvalue()
135+
if __debug__:
136+
log_debug("C: %s", ":".join(map(hex2, data)))
137+
self.socket.sendall(data)
130138

131-
def reset(self):
132-
""" Reset the stream.
139+
self.raw.seek(self.raw.truncate(0))
140+
141+
def _recv(self, size):
142+
# If data is needed, keep reading until all bytes have been received
143+
remaining = size - len(self._recv_buffer)
144+
ready_to_read = None
145+
while remaining > 0:
146+
# Read up to the required amount remaining
147+
b = self.socket.recv(8192)
148+
if b:
149+
if __debug__: log_debug("S: %s", ":".join(map(hex2, b)))
150+
else:
151+
if ready_to_read is not None:
152+
raise ProtocolError("Server closed connection")
153+
remaining -= len(b)
154+
self._recv_buffer += b
155+
156+
# If more is required, wait for available network data
157+
if remaining > 0:
158+
ready_to_read, _, _ = select((self.socket,), (), (), 0)
159+
while not ready_to_read:
160+
ready_to_read, _, _ = select((self.socket,), (), (), 0)
161+
162+
# Split off the amount of data required and keep the rest in the buffer
163+
data, self._recv_buffer = self._recv_buffer[:size], self._recv_buffer[size:]
164+
return data
165+
166+
def chunk_reader(self):
167+
chunk_size = -1
168+
while chunk_size != 0:
169+
chunk_header = self._recv(2)
170+
chunk_size, = struct_unpack_from(">H", chunk_header)
171+
if chunk_size > 0:
172+
data = self._recv(chunk_size)
173+
yield data
174+
175+
def close(self):
176+
""" Shut down and close the connection.
133177
"""
134-
self.flush()
135-
raw = self.raw
136-
return raw.seek(raw.truncate(0))
178+
if __debug__: log_info("~~ [CLOSE]")
179+
socket = self.socket
180+
socket.shutdown(SHUT_RDWR)
181+
socket.close()
137182

138183

139-
class ResponseSubscriber(object):
184+
class Response(object):
185+
""" Subscriber object for a full response (zero or
186+
more detail messages followed by one summary message).
187+
"""
140188

141189
def __init__(self, connection):
142190
self.connection = connection
@@ -160,151 +208,92 @@ def deliver(self, messages):
160208
if signature in SUMMARY:
161209
self.__complete = True
162210
if signature == FAILURE:
163-
self.connection.ack_failure()
211+
self.ack_failure()
164212

165213
def consume(self):
166-
self.connection.receive_all(self)
214+
self.connection.fetch_all(self)
215+
216+
def ack_failure(self):
217+
""" Queue an acknowledgement for a previous failure.
218+
"""
219+
220+
def on_failure(metadata):
221+
raise ProtocolError("Could not acknowledge failure")
222+
223+
subscriber = Response(self)
224+
subscriber.on_failure = on_failure
225+
self.connection.append(ACK_FAILURE, response=subscriber)
167226

168227

169228
class Connection(object):
170229
""" Server connection through which all protocol messages
171230
are sent and received. This class is designed for protocol
172231
version 1.
232+
233+
.. note:: logs at INFO level
173234
"""
174235

175236
def __init__(self, sock, **config):
176-
self.socket = sock
177-
self.raw = ChunkWriter()
178-
self.packer = Packer(self.raw)
179-
self.inbox = deque()
180-
self.outbox = deque()
181-
self._recv_buffer = b""
237+
self.channel = ChunkChannel(sock)
238+
self.packer = Packer(self.channel)
239+
self.responses = deque()
182240

183-
def append(self, signature, fields=(), subscriber=None):
184-
""" Add a message to the outgoing queue.
185-
"""
186-
self.outbox.append((signature, fields))
187-
self.inbox.append(subscriber)
241+
# Determine the user agent and ensure it is a Unicode value
242+
user_agent = config.get("user-agent", DEFAULT_USER_AGENT)
243+
if isinstance(user_agent, bytes):
244+
user_agent = user_agent.decode("UTF-8")
188245

189-
def send(self):
190-
""" Send all queued messages to the server.
191-
"""
246+
def on_failure(metadata):
247+
raise ProtocolError("Initialisation failed")
192248

193-
# Shortcuts to avoid too many dots
194-
raw = self.raw
195-
packer = self.packer
196-
pack_struct_header = packer.pack_struct_header
197-
pack = packer.pack
198-
flush = raw.flush
199-
200-
for signature, fields in self.outbox:
201-
pack_struct_header(len(fields), signature)
202-
for field in fields:
203-
pack(field)
204-
flush(zero_chunk=True)
205-
206-
data = raw.to_bytes()
207-
if __debug__: log_debug("C: %s", ":".join(map(hex2, data)))
208-
self.socket.sendall(data)
249+
response = Response(self)
250+
response.on_failure = on_failure
209251

210-
raw.reset()
211-
self.outbox.clear()
252+
self.append(INIT, (user_agent,), response=response)
253+
self.send()
254+
response.consume()
212255

213-
def _recv(self, size):
214-
# If data is needed, keep reading until all bytes have been received
215-
remaining = size - len(self._recv_buffer)
216-
ready_to_read = None
217-
while remaining > 0:
218-
# Read up to the required amount remaining
219-
b = self.socket.recv(8192)
220-
if b:
221-
if __debug__: log_debug("S: %s", ":".join(map(hex2, b)))
222-
else:
223-
if ready_to_read is not None:
224-
raise ProtocolError("Server closed connection")
225-
remaining -= len(b)
226-
self._recv_buffer += b
256+
def append(self, signature, fields=(), response=None):
257+
""" Add a message to the outgoing queue.
258+
"""
259+
if __debug__:
260+
log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields)))
227261

228-
# If more is required, wait for available network data
229-
if remaining > 0:
230-
ready_to_read, _, _ = select((self.socket,), (), (), 0)
231-
while not ready_to_read:
232-
ready_to_read, _, _ = select((self.socket,), (), (), 0)
262+
self.packer.pack_struct_header(len(fields), signature)
263+
for field in fields:
264+
self.packer.pack(field)
265+
self.channel.flush(end_of_message=True)
266+
self.responses.append(response)
233267

234-
# Split off the amount of data required and keep the rest in the buffer
235-
data, self._recv_buffer = self._recv_buffer[:size], self._recv_buffer[size:]
236-
return data
268+
def send(self):
269+
""" Send all queued messages to the server.
270+
"""
271+
self.channel.send()
237272

238-
def receive_next(self):
273+
def fetch_next(self):
239274
""" Receive exactly one message from the server.
240275
"""
241276
raw = BytesIO()
242277
unpack = Unpacker(raw).unpack
243-
244-
# Receive chunks of data until chunk_size == 0
245-
chunk_size = None
246-
while chunk_size != 0:
247-
if chunk_size is None:
248-
# Chunk header
249-
data = self._recv(2)
250-
chunk_size, = struct_unpack_from(">H", data)
251-
else:
252-
# Chunk content
253-
data = self._recv(chunk_size)
254-
raw.write(data)
255-
chunk_size = None
278+
raw.writelines(self.channel.chunk_reader())
256279

257280
# Unpack the message from the raw byte stream into the inbox
258281
raw.seek(0)
259-
response = self.inbox[0]
282+
response = self.responses[0]
260283
response.deliver(unpack())
261284
if response.complete:
262-
self.inbox.popleft()
285+
self.responses.popleft()
263286
raw.close()
264287

265-
def receive_all(self, response):
266-
receive_next = self.receive_next
288+
def fetch_all(self, response):
289+
fetch_next = self.fetch_next
267290
while not response.complete:
268-
receive_next()
269-
270-
def init(self, user_agent):
271-
""" Initialise a connection with a user agent string.
272-
"""
273-
274-
# Ensure the user agent is a Unicode value
275-
if isinstance(user_agent, bytes):
276-
user_agent = user_agent.decode("UTF-8")
277-
278-
if __debug__: log_info("C: INIT %r", user_agent)
279-
280-
def on_failure(metadata):
281-
raise ProtocolError("Initialisation failed")
282-
283-
subscriber = ResponseSubscriber(self)
284-
subscriber.on_failure = on_failure
285-
self.append(INIT, (user_agent,), subscriber=subscriber)
286-
self.send()
287-
subscriber.consume()
288-
289-
def ack_failure(self):
290-
""" Queue an acknowledgement for a previous failure.
291-
"""
292-
if __debug__: log_info("C: ACK_FAILURE")
293-
294-
def on_failure(metadata):
295-
raise ProtocolError("Could not acknowledge failure")
296-
297-
subscriber = ResponseSubscriber(self)
298-
subscriber.on_failure = on_failure
299-
self.append(ACK_FAILURE, subscriber=subscriber)
291+
fetch_next()
300292

301293
def close(self):
302294
""" Shut down and close the connection.
303295
"""
304-
if __debug__: log_info("~~ [CLOSE]")
305-
socket = self.socket
306-
socket.shutdown(SHUT_RDWR)
307-
socket.close()
296+
self.channel.close()
308297

309298

310299
def connect(host, port=None, **config):

0 commit comments

Comments
 (0)