Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ This code was written with PyTorch<0.4, but most people must be using PyTorch>=0
* [tensorboardX](https://github.com/lanpa/tensorboard-pytorch) (TensorBoard for PyTorch)

### 2. Train DnCNN-S (DnCNN with known noise level)
```
```shell
python train.py \
--preprocess True \
--num_of_layers 17 \
Expand All @@ -31,7 +31,7 @@ python train.py \
* *noiseL* is used for training and *val_noiseL* is used for validation. They should be set to the same value for unbiased validation. You can set whatever noise level you need.

### 3. Train DnCNN-B (DnCNN with blind noise level)
```
```shell
python train.py \
--preprocess True \
--num_of_layers 20 \
Expand All @@ -44,7 +44,7 @@ python train.py \
* *noiseL* is ingnored when training DnCNN-B. You can set *val_noiseL* to whatever you need.

### 4. Test
```
```shell
python test.py \
--num_of_layers 17 \
--logdir logs/DnCNN-S-15 \
Expand Down Expand Up @@ -77,7 +77,7 @@ python test.py \
## Tricks useful for boosting performance
* Parameter initialization:
Use *kaiming_normal* initialization for *Conv*; Pay attention to the initialization of *BatchNorm*
```
```python
def weights_init_kaiming(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
Expand All @@ -90,11 +90,11 @@ def weights_init_kaiming(m):
```
* The definition of loss function
Set *size_average* to be False when defining the loss function. When *size_average=True*, the **pixel-wise average** will be computed, but what we need is **sample-wise average**.
```
```python
criterion = nn.MSELoss(size_average=False)
```
The computation of loss will be like:
```
```python
loss = criterion(out_train, noise) / (imgn_train.size()[0]*2)
```
where we divide the sum over one batch of samples by *2N*, with *N* being # samples.