|
13 | 13 | def build_matrix(ldpc_code_params): |
14 | 14 | """ |
15 | 15 | Build the parity check and generator matrices from parameters dictionary and add the result in this dictionary. |
| 16 | + Generator matrix is valid only for triangular systematic LDPC codes. |
16 | 17 |
|
17 | 18 | Parameters |
18 | 19 | ---------- |
@@ -140,35 +141,6 @@ def get_ldpc_code_params(ldpc_design_filename, compute_matrix=False): |
140 | 141 | return ldpc_code_params |
141 | 142 |
|
142 | 143 |
|
143 | | -def sum_product_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs, |
144 | | - vnode_msgs, cnode_vnode_map, max_cnode_deg, max_vnode_deg): |
145 | | - |
146 | | - start_idx = cnode_idx*max_cnode_deg |
147 | | - offset = cnode_deg_list[cnode_idx] |
148 | | - vnode_list = cnode_adj_list[start_idx:start_idx+offset] |
149 | | - vnode_list_msgs_tanh = np.tanh(vnode_msgs[vnode_list*max_vnode_deg + |
150 | | - cnode_vnode_map[start_idx:start_idx+offset]] / 2.0) |
151 | | - msg_prod = vnode_list_msgs_tanh.prod(0) |
152 | | - |
153 | | - # Compute messages on outgoing edges using the incoming message product |
154 | | - np.clip(2 * np.arctanh(msg_prod / vnode_list_msgs_tanh), |
155 | | - -_llr_max, _llr_max, cnode_msgs[start_idx:start_idx+offset]) |
156 | | - |
157 | | - |
158 | | -def min_sum_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs, |
159 | | - vnode_msgs, cnode_vnode_map, max_cnode_deg, max_vnode_deg): |
160 | | - |
161 | | - start_idx = cnode_idx*max_cnode_deg |
162 | | - offset = cnode_deg_list[cnode_idx] |
163 | | - vnode_list = cnode_adj_list[start_idx:start_idx+offset] |
164 | | - vnode_list_msgs = vnode_msgs[vnode_list*max_vnode_deg + cnode_vnode_map[start_idx:start_idx+offset]] |
165 | | - |
166 | | - # Compute messages on outgoing edges using the incoming messages |
167 | | - for i in range(start_idx, start_idx+offset): |
168 | | - vnode_list_msgs_excluded = vnode_list_msgs[np.arange(len(vnode_list_msgs)) != i - start_idx, :] |
169 | | - cnode_msgs[i] = np.sign(vnode_list_msgs_excluded).prod(0) * np.abs(vnode_list_msgs_excluded).min(0) |
170 | | - |
171 | | - |
172 | 144 | def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters): |
173 | 145 | """ |
174 | 146 | LDPC Decoder using Belief Propagation (BP). If several blocks are provided, they are all decoded at once. |
@@ -213,79 +185,73 @@ def ldpc_bp_decode(llr_vec, ldpc_code_params, decoder_algorithm, n_iters): |
213 | 185 | # Clip LLRs |
214 | 186 | llr_vec.clip(-_llr_max, _llr_max, llr_vec) |
215 | 187 |
|
216 | | - n_cnodes = ldpc_code_params['n_cnodes'] |
217 | | - n_vnodes = ldpc_code_params['n_vnodes'] |
218 | | - max_cnode_deg = ldpc_code_params['max_cnode_deg'] |
219 | | - max_vnode_deg = ldpc_code_params['max_vnode_deg'] |
220 | | - cnode_adj_list = ldpc_code_params['cnode_adj_list'] |
221 | | - cnode_vnode_map = ldpc_code_params['cnode_vnode_map'] |
222 | | - vnode_adj_list = ldpc_code_params['vnode_adj_list'] |
223 | | - vnode_cnode_map = ldpc_code_params['vnode_cnode_map'] |
224 | | - cnode_deg_list = ldpc_code_params['cnode_deg_list'] |
225 | | - vnode_deg_list = ldpc_code_params['vnode_deg_list'] |
226 | | - |
227 | | - # Handling multi-block situations |
228 | | - n_blocks = llr_vec.size // n_vnodes |
229 | | - llr_vec = llr_vec.reshape(-1, n_blocks, order='F') |
230 | | - |
231 | | - dec_word = np.empty_like(llr_vec, bool) |
232 | | - out_llrs = np.empty_like(llr_vec) |
233 | | - cnode_msgs = np.empty((n_cnodes * max_cnode_deg, n_blocks)) |
234 | | - |
235 | | - if decoder_algorithm == 'SPA': |
236 | | - check_node_update = sum_product_update |
237 | | - elif decoder_algorithm == 'MSA': |
238 | | - check_node_update = min_sum_update |
239 | | - else: |
240 | | - raise NameError('Please input a valid decoder_algorithm string (meanning "SPA" or "MSA").') |
241 | | - |
242 | | - # Initialize vnode messages with the LLR values received |
243 | | - vnode_msgs = llr_vec.repeat(max_vnode_deg, 0) |
244 | | - |
245 | | - # Main loop of Belief Propagation (BP) decoding iterations |
246 | | - for iter_cnt in range(n_iters): |
247 | | - |
248 | | - continue_flag = False |
249 | | - |
250 | | - # Check Node Update |
251 | | - for cnode_idx in range(n_cnodes): |
252 | | - |
253 | | - check_node_update(cnode_idx, cnode_adj_list, cnode_deg_list, cnode_msgs, |
254 | | - vnode_msgs, cnode_vnode_map, max_cnode_deg, max_vnode_deg) |
255 | | - |
256 | | - # Variable Node Update |
257 | | - for vnode_idx in range(n_vnodes): |
258 | | - |
259 | | - # Compute sum of all incoming messages at the variable node |
260 | | - start_idx = vnode_idx*max_vnode_deg |
261 | | - offset = vnode_deg_list[vnode_idx] |
262 | | - cnode_list = vnode_adj_list[start_idx:start_idx+offset] |
263 | | - cnode_list_msgs = cnode_msgs[cnode_list*max_cnode_deg + vnode_cnode_map[start_idx:start_idx+offset]] |
264 | | - msg_sum = cnode_list_msgs.sum(0) |
265 | | - |
266 | | - # Compute messages on outgoing edges using the incoming message sum |
267 | | - vnode_msgs[start_idx:start_idx+offset] = llr_vec[vnode_idx] + msg_sum - cnode_list_msgs |
| 188 | + # Build parity_check_matrix if required |
| 189 | + if ldpc_code_params.get('parity_check_matrix') is None: |
| 190 | + build_matrix(ldpc_code_params) |
268 | 191 |
|
269 | | - # Update output LLRs and decoded word |
270 | | - out_llrs[vnode_idx] = llr_vec[vnode_idx] + msg_sum |
| 192 | + # Initialization |
| 193 | + dec_word = np.signbit(llr_vec) |
| 194 | + out_llrs = llr_vec.copy() |
| 195 | + parity_check_matrix = ldpc_code_params['parity_check_matrix'].astype(float).tocoo() |
271 | 196 |
|
272 | | - np.signbit(out_llrs, out=dec_word) |
| 197 | + for i_start in range(0, llr_vec.size, ldpc_code_params['n_vnodes']): |
| 198 | + i_stop = i_start + ldpc_code_params['n_vnodes'] |
| 199 | + message_matrix = parity_check_matrix.multiply(llr_vec[i_start:i_stop]) |
273 | 200 |
|
274 | | - # Compute early termination using parity check matrix |
275 | | - for cnode_idx in range(n_cnodes): |
276 | | - start_idx = cnode_idx * max_cnode_deg |
277 | | - offset = cnode_deg_list[cnode_idx] |
278 | | - parity_sum = np.logical_xor.reduce(dec_word[cnode_adj_list[start_idx:start_idx + offset]]) |
| 201 | + # Main loop of Belief Propagation (BP) decoding iterations |
| 202 | + for iter_cnt in range(n_iters): |
279 | 203 |
|
280 | | - if parity_sum.any(): |
281 | | - continue_flag = True |
| 204 | + # Compute early termination using parity check matrix |
| 205 | + if np.all(ldpc_code_params['parity_check_matrix'].multiply(dec_word[i_start:i_stop]).sum(1) % 2 == 0): |
282 | 206 | break |
283 | 207 |
|
284 | | - # Stop iterations |
285 | | - if not continue_flag: |
286 | | - break |
287 | | - |
288 | | - return dec_word.squeeze().astype(np.int8), out_llrs.squeeze() |
| 208 | + # Check Node Update |
| 209 | + if decoder_algorithm == 'SPA': |
| 210 | + # Compute incoming messages |
| 211 | + message_matrix.data *= .5 |
| 212 | + np.tanh(message_matrix.data, out=message_matrix.data) |
| 213 | + |
| 214 | + # Runtime Warnings are expected when llr = 0. No warn should be raised as this case are expected. |
| 215 | + with np.errstate(divide='ignore', invalid='ignore'): |
| 216 | + # Compute product as exponent of the sum of logarithm |
| 217 | + log2_msg_matrix = message_matrix.astype(complex).copy() |
| 218 | + np.log2(message_matrix.data.astype(complex), out=log2_msg_matrix.data) |
| 219 | + msg_products = np.exp2(log2_msg_matrix.sum(1)).real |
| 220 | + |
| 221 | + # Compute outgoing messages |
| 222 | + message_matrix.data = 1 / message_matrix.data |
| 223 | + message_matrix = message_matrix.multiply(msg_products) |
| 224 | + message_matrix.data.clip(-1, 1, message_matrix.data) |
| 225 | + np.arctanh(message_matrix.data, out=message_matrix.data) |
| 226 | + message_matrix.data *= 2 |
| 227 | + message_matrix.data.clip(-_llr_max, _llr_max, message_matrix.data) |
| 228 | + |
| 229 | + elif decoder_algorithm == 'MSA': |
| 230 | + message_matrix = message_matrix.tocsr() |
| 231 | + for row_idx in range(message_matrix.shape[0]): |
| 232 | + begin_row = message_matrix.indptr[row_idx] |
| 233 | + end_row = message_matrix.indptr[row_idx+1] |
| 234 | + row_data = message_matrix.data[begin_row:end_row].copy() |
| 235 | + indexes = np.arange(len(row_data)) |
| 236 | + for j, i in enumerate(range(begin_row, end_row)): |
| 237 | + other_val = row_data[indexes != j] |
| 238 | + message_matrix.data[i] = np.sign(other_val).prod() * np.abs(other_val).min() |
| 239 | + else: |
| 240 | + raise NameError('Please input a valid decoder_algorithm string (meanning "SPA" or "MSA").') |
| 241 | + |
| 242 | + # Variable Node Update |
| 243 | + msg_sum = np.array(message_matrix.sum(0)).squeeze() |
| 244 | + message_matrix *= -1 |
| 245 | + message_matrix += parity_check_matrix.multiply(msg_sum + llr_vec[i_start:i_stop]) |
| 246 | + |
| 247 | + out_llrs = msg_sum + llr_vec[i_start:i_stop] |
| 248 | + np.signbit(out_llrs, out=dec_word[i_start:i_stop]) |
| 249 | + |
| 250 | + # Reformat outputs |
| 251 | + n_blocks = llr_vec.size // ldpc_code_params['n_vnodes'] |
| 252 | + dec_word = dec_word.reshape(-1, n_blocks, order='F').squeeze().astype(np.int8) |
| 253 | + out_llrs = out_llrs.reshape(-1, n_blocks, order='F').squeeze() |
| 254 | + return dec_word, out_llrs |
289 | 255 |
|
290 | 256 |
|
291 | 257 | def write_ldpc_params(parity_check_matrix, file_path): |
|
0 commit comments