Skip to content

Commit 60ba825

Browse files
committed
Improves the code, such as adding missing "const"
This is a minor update to perform minor tidying of the code. The largest change is the addition of missing "const" qualifiers in the C and CUDA code.
1 parent e1a2514 commit 60ba825

20 files changed

Lines changed: 1328 additions & 1249 deletions

src/deepwave/acoustic.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ __declspec(dllexport)
615615
// Save snapshots
616616
#define SAVE_SNAPSHOT(name, grad_cond) \
617617
if (grad_cond) { \
618-
int64_t step_idx = t / step_ratio; \
618+
int64_t const step_idx = t / step_ratio; \
619619
storage_save_snapshot_cpu( \
620620
name##_store_1_t, name##_store_2_t, fp_##name, storage_mode, \
621621
storage_compression, step_idx, shot_bytes_uncomp, shot_bytes_comp, \
@@ -952,7 +952,7 @@ __declspec(dllexport)
952952
: 0)) * \
953953
(int64_t)shot_bytes_comp; \
954954
if ((grad_cond) && ((t % step_ratio) == 0)) { \
955-
int64_t step_idx = t / step_ratio; \
955+
int64_t const step_idx = t / step_ratio; \
956956
storage_load_snapshot_cpu( \
957957
(void *)name##_store_1_t, name##_store_2_t, fp_##name, storage_mode, \
958958
storage_compression, step_idx, shot_bytes_uncomp, shot_bytes_comp, \

src/deepwave/acoustic.cu

Lines changed: 299 additions & 279 deletions
Large diffs are not rendered by default.

src/deepwave/acoustic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def forward(
854854
receiver_amplitudes[i].resize_(nt, n_shots, n_per_shot)
855855
receiver_amplitudes[i].fill_(0)
856856

857-
stream = 0
857+
stream: Union[int, torch.Stream] = 0
858858
if is_cuda:
859859
aux = models[0].get_device()
860860
stream = torch.cuda.current_stream(aux)
@@ -1068,7 +1068,7 @@ def backward(ctx: Any, *args: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...
10681068

10691069
model_batched = [m.ndim == ndim + 1 and m.shape[0] > 1 for m in models[:2]]
10701070

1071-
stream = 0
1071+
stream: Union[int, torch.Stream] = 0
10721072
if is_cuda:
10731073
aux = models[0].get_device()
10741074
stream = torch.cuda.current_stream(aux)

src/deepwave/backend_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def get_scalar_born_forward_template(ndim: int) -> List[Any]:
114114
args += [c_void_p] * 4 # w_store_1a, w_store_1b, w_store_2, w_store_3
115115
args += [c_void_p] # w_filenames
116116
args += [c_void_p] * 4 # wsc_store_1a, w_store_1b, wsc_store_2, wsc_store_3
117-
args += [c_void_p] #sc w_filenames
117+
args += [c_void_p] # sc w_filenames
118118
args += [c_void_p] * 2 # r, rsc
119119
args += [c_void_p] * (3 * ndim) # a, b, dbdx
120120
args += [c_void_p] * 3 # sources_i, receivers_i, receiverssc_i

src/deepwave/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999

100100
def get_ptrs(
101101
self,
102-
) -> Tuple[int, int, int, int, int, int, ctypes.Array[ctypes.c_char_p]]:
102+
) -> Tuple[int, int, int, int, ctypes.Array[ctypes.c_char_p]]:
103103
"""Return pointers to the storage and filenames array."""
104104
return (
105105
self.store_1a.data_ptr(),

src/deepwave/elastic.c

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,13 +341,14 @@ static inline void add_pressure(
341341

342342
static inline void record_pressure(
343343
#if DW_NDIM >= 3
344-
DW_DTYPE const *__restrict sigmazz,
344+
DW_DTYPE const *__restrict const sigmazz,
345345
#endif
346346
#if DW_NDIM >= 2
347-
DW_DTYPE const *__restrict sigmayy,
347+
DW_DTYPE const *__restrict const sigmayy,
348348
#endif
349-
DW_DTYPE const *__restrict sigmaxx, int64_t const *__restrict locations,
350-
DW_DTYPE *__restrict amplitudes, int64_t n) {
349+
DW_DTYPE const *__restrict const sigmaxx,
350+
int64_t const *__restrict const locations,
351+
DW_DTYPE *__restrict const amplitudes, int64_t n) {
351352
int64_t i;
352353
#pragma omp simd
353354
for (i = 0; i < n; ++i) {
@@ -1231,7 +1232,7 @@ __declspec(dllexport)
12311232

12321233
#define SAVE_SNAPSHOT(name, grad_cond) \
12331234
if (grad_cond) { \
1234-
int64_t step_idx = t / step_ratio; \
1235+
int64_t const step_idx = t / step_ratio; \
12351236
storage_save_snapshot_cpu( \
12361237
name##_store_1_t, name##_store_2_t, fp_##name, storage_mode, \
12371238
storage_compression, step_idx, shot_bytes_uncomp, shot_bytes_comp, \
@@ -1656,7 +1657,7 @@ __declspec(dllexport)
16561657
: 0)) * \
16571658
(int64_t)shot_bytes_comp; \
16581659
if ((grad_cond) && ((t % step_ratio) == 0)) { \
1659-
int64_t step_idx = t / step_ratio; \
1660+
int64_t const step_idx = t / step_ratio; \
16601661
storage_load_snapshot_cpu( \
16611662
(void *)name##_store_1_t, name##_store_2_t, fp_##name, storage_mode, \
16621663
storage_compression, step_idx, shot_bytes_uncomp, shot_bytes_comp, \

0 commit comments

Comments
 (0)