Skip to content

Commit f4c68b9

Browse files
authored
Merge pull request #55 from BastienTr/MBF-FD
Fix a bug & new MIMO detector & Refactoring
2 parents 463b8ad + ee25547 commit f4c68b9

File tree

9 files changed

+538
-235
lines changed

9 files changed

+538
-235
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ Available Features
5454
- OFDM Tx/Rx signal processing
5555
- MIMO Maximum Likelihood (ML) Detection.
5656
- MIMO K-best Schnorr-Euchner Detection.
57+
- MIMO Best-First Detection.
5758
- Convert channel matrix to Bit-level representation.
5859
- Computation of LogLikelihood ratio using max-log approximation.
5960

commpy/channelcoding/convcode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def __init__(self, memory, g_matrix, feedback = None, code_type = 'default'):
152152

153153
output_generator_array[l] = generator_array[0]
154154
if l == 0:
155-
feedback_array = (dec2bitarray(feedback, memory[l]) * shift_register[0:memory[l]]).sum()
155+
feedback_array = (dec2bitarray(feedback, memory[l])[1:] * shift_register[0:memory[l]]).sum()
156156
shift_register[1:memory[l]] = \
157157
shift_register[0:memory[l] - 1]
158158
shift_register[0] = (dec2bitarray(current_input,

commpy/channelcoding/ldpc.py

Lines changed: 62 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
def build_matrix(ldpc_code_params):
1414
"""
1515
Build the parity check and generator matrices from parameters dictionary and add the result in this dictionary.
16+
Generator matrix is valid only for triangular systematic LDPC codes.
1617
1718
Parameters
1819
----------
@@ -140,35 +141,6 @@ def get_ldpc_code_params(ldpc_design_filename, compute_matrix=False):
140141
return ldpc_code_params
141142

142143

143-
def sum_product_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs,
144-
vnode_msgs, cnode_vnode_map, max_cnode_deg, max_vnode_deg):
145-
146-
start_idx = cnode_idx*max_cnode_deg
147-
offset = cnode_deg_list[cnode_idx]
148-
vnode_list = cnode_adj_list[start_idx:start_idx+offset]
149-
vnode_list_msgs_tanh = np.tanh(vnode_msgs[vnode_list*max_vnode_deg +
150-
cnode_vnode_map[start_idx:start_idx+offset]] / 2.0)
151-
msg_prod = vnode_list_msgs_tanh.prod(0)
152-
153-
# Compute messages on outgoing edges using the incoming message product
154-
np.clip(2 * np.arctanh(msg_prod / vnode_list_msgs_tanh),
155-
-_llr_max, _llr_max, cnode_msgs[start_idx:start_idx+offset])
156-
157-
158-
def min_sum_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs,
159-
vnode_msgs, cnode_vnode_map, max_cnode_deg, max_vnode_deg):
160-
161-
start_idx = cnode_idx*max_cnode_deg
162-
offset = cnode_deg_list[cnode_idx]
163-
vnode_list = cnode_adj_list[start_idx:start_idx+offset]
164-
vnode_list_msgs = vnode_msgs[vnode_list*max_vnode_deg + cnode_vnode_map[start_idx:start_idx+offset]]
165-
166-
# Compute messages on outgoing edges using the incoming messages
167-
for i in range(start_idx, start_idx+offset):
168-
vnode_list_msgs_excluded = vnode_list_msgs[np.arange(len(vnode_list_msgs)) != i - start_idx, :]
169-
cnode_msgs[i] = np.sign(vnode_list_msgs_excluded).prod(0) * np.abs(vnode_list_msgs_excluded).min(0)
170-
171-
172144
def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
173145
"""
174146
LDPC Decoder using Belief Propagation (BP). If several blocks are provided, they are all decoded at once.
@@ -213,79 +185,73 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters):
213185
# Clip LLRs
214186
llr_vec.clip(-_llr_max, _llr_max, llr_vec)
215187

216-
n_cnodes = ldpc_code_params['n_cnodes']
217-
n_vnodes = ldpc_code_params['n_vnodes']
218-
max_cnode_deg = ldpc_code_params['max_cnode_deg']
219-
max_vnode_deg = ldpc_code_params['max_vnode_deg']
220-
cnode_adj_list = ldpc_code_params['cnode_adj_list']
221-
cnode_vnode_map = ldpc_code_params['cnode_vnode_map']
222-
vnode_adj_list = ldpc_code_params['vnode_adj_list']
223-
vnode_cnode_map = ldpc_code_params['vnode_cnode_map']
224-
cnode_deg_list = ldpc_code_params['cnode_deg_list']
225-
vnode_deg_list = ldpc_code_params['vnode_deg_list']
226-
227-
# Handling multi-block situations
228-
n_blocks = llr_vec.size // n_vnodes
229-
llr_vec = llr_vec.reshape(-1, n_blocks, order='F')
230-
231-
dec_word = np.empty_like(llr_vec, bool)
232-
out_llrs = np.empty_like(llr_vec)
233-
cnode_msgs = np.empty((n_cnodes * max_cnode_deg, n_blocks))
234-
235-
if decoder_algorithm == 'SPA':
236-
check_node_update = sum_product_update
237-
elif decoder_algorithm == 'MSA':
238-
check_node_update = min_sum_update
239-
else:
240-
raise NameError('Please input a valid decoder_algorithm string (meanning "SPA" or "MSA").')
241-
242-
# Initialize vnode messages with the LLR values received
243-
vnode_msgs = llr_vec.repeat(max_vnode_deg, 0)
244-
245-
# Main loop of Belief Propagation (BP) decoding iterations
246-
for iter_cnt in range(n_iters):
247-
248-
continue_flag = False
249-
250-
# Check Node Update
251-
for cnode_idx in range(n_cnodes):
252-
253-
check_node_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs,
254-
vnode_msgs, cnode_vnode_map, max_cnode_deg, max_vnode_deg)
255-
256-
# Variable Node Update
257-
for vnode_idx in range(n_vnodes):
258-
259-
# Compute sum of all incoming messages at the variable node
260-
start_idx = vnode_idx*max_vnode_deg
261-
offset = vnode_deg_list[vnode_idx]
262-
cnode_list = vnode_adj_list[start_idx:start_idx+offset]
263-
cnode_list_msgs = cnode_msgs[cnode_list*max_cnode_deg + vnode_cnode_map[start_idx:start_idx+offset]]
264-
msg_sum = cnode_list_msgs.sum(0)
265-
266-
# Compute messages on outgoing edges using the incoming message sum
267-
vnode_msgs[start_idx:start_idx+offset] = llr_vec[vnode_idx] + msg_sum - cnode_list_msgs
188+
# Build parity_check_matrix if required
189+
if ldpc_code_params.get('parity_check_matrix') is None:
190+
build_matrix(ldpc_code_params)
268191

269-
# Update output LLRs and decoded word
270-
out_llrs[vnode_idx] = llr_vec[vnode_idx] + msg_sum
192+
# Initialization
193+
dec_word = np.signbit(llr_vec)
194+
out_llrs = llr_vec.copy()
195+
parity_check_matrix = ldpc_code_params['parity_check_matrix'].astype(float).tocoo()
271196

272-
np.signbit(out_llrs, out=dec_word)
197+
for i_start in range(0, llr_vec.size, ldpc_code_params['n_vnodes']):
198+
i_stop = i_start + ldpc_code_params['n_vnodes']
199+
message_matrix = parity_check_matrix.multiply(llr_vec[i_start:i_stop])
273200

274-
# Compute early termination using parity check matrix
275-
for cnode_idx in range(n_cnodes):
276-
start_idx = cnode_idx * max_cnode_deg
277-
offset = cnode_deg_list[cnode_idx]
278-
parity_sum = np.logical_xor.reduce(dec_word[cnode_adj_list[start_idx:start_idx + offset]])
201+
# Main loop of Belief Propagation (BP) decoding iterations
202+
for iter_cnt in range(n_iters):
279203

280-
if parity_sum.any():
281-
continue_flag = True
204+
# Compute early termination using parity check matrix
205+
if np.all(ldpc_code_params['parity_check_matrix'].multiply(dec_word[i_start:i_stop]).sum(1) % 2 == 0):
282206
break
283207

284-
# Stop iterations
285-
if not continue_flag:
286-
break
287-
288-
return dec_word.squeeze().astype(np.int8), out_llrs.squeeze()
208+
# Check Node Update
209+
if decoder_algorithm == 'SPA':
210+
# Compute incoming messages
211+
message_matrix.data *= .5
212+
np.tanh(message_matrix.data, out=message_matrix.data)
213+
214+
# Runtime Warnings are expected when llr = 0. No warn should be raised as this case are expected.
215+
with np.errstate(divide='ignore', invalid='ignore'):
216+
# Compute product as exponent of the sum of logarithm
217+
log2_msg_matrix = message_matrix.astype(complex).copy()
218+
np.log2(message_matrix.data.astype(complex), out=log2_msg_matrix.data)
219+
msg_products = np.exp2(log2_msg_matrix.sum(1)).real
220+
221+
# Compute outgoing messages
222+
message_matrix.data = 1 / message_matrix.data
223+
message_matrix = message_matrix.multiply(msg_products)
224+
message_matrix.data.clip(-1, 1, message_matrix.data)
225+
np.arctanh(message_matrix.data, out=message_matrix.data)
226+
message_matrix.data *= 2
227+
message_matrix.data.clip(-_llr_max, _llr_max, message_matrix.data)
228+
229+
elif decoder_algorithm == 'MSA':
230+
message_matrix = message_matrix.tocsr()
231+
for row_idx in range(message_matrix.shape[0]):
232+
begin_row = message_matrix.indptr[row_idx]
233+
end_row = message_matrix.indptr[row_idx+1]
234+
row_data = message_matrix.data[begin_row:end_row].copy()
235+
indexes = np.arange(len(row_data))
236+
for j, i in enumerate(range(begin_row, end_row)):
237+
other_val = row_data[indexes != j]
238+
message_matrix.data[i] = np.sign(other_val).prod() * np.abs(other_val).min()
239+
else:
240+
raise NameError('Please input a valid decoder_algorithm string (meanning "SPA" or "MSA").')
241+
242+
# Variable Node Update
243+
msg_sum = np.array(message_matrix.sum(0)).squeeze()
244+
message_matrix *= -1
245+
message_matrix += parity_check_matrix.multiply(msg_sum + llr_vec[i_start:i_stop])
246+
247+
out_llrs = msg_sum + llr_vec[i_start:i_stop]
248+
np.signbit(out_llrs, out=dec_word[i_start:i_stop])
249+
250+
# Reformat outputs
251+
n_blocks = llr_vec.size // ldpc_code_params['n_vnodes']
252+
dec_word = dec_word.reshape(-1, n_blocks, order='F').squeeze().astype(np.int8)
253+
out_llrs = out_llrs.reshape(-1, n_blocks, order='F').squeeze()
254+
return dec_word, out_llrs
289255

290256

291257
def write_ldpc_params(parity_check_matrix, file_path):

commpy/channelcoding/tests/test_ldpc.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,31 +38,32 @@ def test_ldpc_bp_decode(self):
3838
tx_codeword = zeros(N, int)
3939
ldpcbp_iters = 100
4040

41-
fer_array_ref = array([200.0/1000, 200.0/2000])
42-
fer_array_test = zeros(len(snr_list))
41+
for decoder_algorithm in ('MSA', 'SPA'):
42+
fer_array_ref = array((.2, .1))
43+
fer_array_test = zeros(len(snr_list))
4344

44-
for idx, ebno in enumerate(snr_list):
45+
for idx, ebno in enumerate(snr_list):
4546

46-
noise_std = 1/sqrt((10**(ebno/10.0))*rate*2/Es)
47-
fer_cnt_bp = 0
47+
noise_std = 1/sqrt((10**(ebno/10.0))*rate*2/Es)
48+
fer_cnt_bp = 0
4849

49-
for iter_cnt in range(niters):
50+
for iter_cnt in range(niters):
5051

51-
awgn_array = noise_std * randn(N)
52-
rx_word = 1-(2*tx_codeword) + awgn_array
53-
rx_llrs = 2.0*rx_word/(noise_std**2)
52+
awgn_array = noise_std * randn(N)
53+
rx_word = 1-(2*tx_codeword) + awgn_array
54+
rx_llrs = 2.0*rx_word/(noise_std**2)
5455

55-
[dec_word, out_llrs] = ldpc_bp_decode(rx_llrs, ldpc_code_params, 'SPA', ldpcbp_iters)
56+
[dec_word, _] = ldpc_bp_decode(rx_llrs, ldpc_code_params, decoder_algorithm, ldpcbp_iters)
5657

57-
num_bit_errors = hamming_dist(tx_codeword, dec_word.reshape(-1))
58-
if num_bit_errors > 0:
59-
fer_cnt_bp += 1
58+
if hamming_dist(tx_codeword, dec_word.reshape(-1)):
59+
fer_cnt_bp += 1
6060

61-
if fer_cnt_bp >= 200:
62-
fer_array_test[idx] = float(fer_cnt_bp) / (iter_cnt + 1) / n_blocks
63-
break
61+
if fer_cnt_bp >= 50:
62+
fer_array_test[idx] = float(fer_cnt_bp) / (iter_cnt + 1) / n_blocks
63+
break
6464

65-
assert_allclose(fer_array_test, fer_array_ref, rtol=.5, atol=0)
65+
assert_allclose(fer_array_test, fer_array_ref, rtol=.5, atol=0,
66+
err_msg=decoder_algorithm + ' algorithm does not perform as expected.')
6667

6768
def test_write_ldpc_params(self):
6869
with TemporaryDirectory() as tmp_dir:

0 commit comments

Comments
 (0)