Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/asr/emformer_rnnt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def forward(self, input):


class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
def __init__(self, optimizer, warmup_updates, last_epoch=-1):
self.warmup_updates = warmup_updates
super().__init__(optimizer, last_epoch=last_epoch)

Expand Down
5 changes: 1 addition & 4 deletions examples/asr/librispeech_conformer_rnnt/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
force_anneal_step (int): scheduler step at which annealing of learning rate begins.
anneal_factor (float): factor to scale base learning rate by at each annealing step.
last_epoch (int, optional): The index of last epoch. (Default: -1)
verbose (bool, optional): If ``True``, prints a message to stdout for
each update. (Default: ``False``)
"""

def __init__(
Expand All @@ -38,12 +36,11 @@ def __init__(
force_anneal_step: int,
anneal_factor: float,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_steps
self.force_anneal_step = force_anneal_step
self.anneal_factor = anneal_factor
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
super().__init__(optimizer, last_epoch=last_epoch)

def get_lr(self):
if self._step_count < self.force_anneal_step:
Expand Down
5 changes: 1 addition & 4 deletions examples/asr/librispeech_conformer_rnnt_biasing/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
force_anneal_step (int): scheduler step at which annealing of learning rate begins.
anneal_factor (float): factor to scale base learning rate by at each annealing step.
last_epoch (int, optional): The index of last epoch. (Default: -1)
verbose (bool, optional): If ``True``, prints a message to stdout for
each update. (Default: ``False``)
"""

def __init__(
Expand All @@ -37,12 +35,11 @@ def __init__(
force_anneal_step: int,
anneal_factor: float,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_steps
self.force_anneal_step = force_anneal_step
self.anneal_factor = anneal_factor
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
super().__init__(optimizer, last_epoch=last_epoch)

def get_lr(self):
if self._step_count < self.force_anneal_step:
Expand Down
3 changes: 1 addition & 2 deletions examples/avsr/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ def __init__(
total_epochs: int,
steps_per_epoch: int,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_epochs * steps_per_epoch
self.total_steps = total_epochs * steps_per_epoch
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
super().__init__(optimizer, last_epoch=last_epoch)

def get_lr(self):
if self._step_count < self.warmup_steps:
Expand Down
6 changes: 2 additions & 4 deletions examples/hubert/lightning_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,10 @@ def __init__(
warmup_updates: int,
max_updates: int,
last_epoch: int = -1,
verbose: bool = False,
):
self.warmup_updates = warmup_updates
self.max_updates = max_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
super().__init__(optimizer, last_epoch=last_epoch)

def get_lr(self):
if self._step_count <= self.warmup_updates:
Expand All @@ -62,15 +61,14 @@ def __init__(
init_lr_scale: float = 0.01,
final_lr_scale: float = 0.05,
last_epoch: int = -1,
verbose: bool = False,
):
self.warmup_updates = warmup_updates
self.hold_updates = hold_updates
self.decay_updates = decay_updates
self.init_lr_scale = init_lr_scale
self.final_lr_scale = final_lr_scale

super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
super().__init__(optimizer, last_epoch=last_epoch)

def get_lr(self):
if self._step_count <= self.warmup_updates:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ def __init__(
warmup_updates: int,
max_updates: int,
last_epoch: int = -1,
verbose: bool = False,
):
self.warmup_updates = warmup_updates
self.max_updates = max_updates
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
super().__init__(optimizer, last_epoch=last_epoch)

def get_lr(self):
if self._step_count <= self.warmup_updates:
Expand Down
4 changes: 1 addition & 3 deletions src/torchaudio/datasets/cmuarctic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import csv
import os
from pathlib import Path
from typing import Tuple, Union
Expand Down Expand Up @@ -129,8 +128,7 @@ def __init__(
self._text = os.path.join(self._path, self._folder_text, self._file_text)

with open(self._text, "r") as text:
walker = csv.reader(text, delimiter="\n")
self._walker = list(walker)
self._walker = [[line.strip()] for line in text if line.strip()]

def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
"""Load the n-th sample from the dataset.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")
_info = staticmethod(partial(get_info_func(), backend="ffmpeg"))

def test_pathlike(self):
"""FFmpeg dispatcher can query audio data from pathlike object"""
Expand Down Expand Up @@ -315,7 +315,7 @@ def test_gsm(self):
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestInfoOpus(PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")
_info = staticmethod(partial(get_info_func(), backend="ffmpeg"))

@parameterized.expand(
list(
Expand All @@ -341,7 +341,7 @@ def test_opus(self, bitrate, num_channels, compression_level):
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestLoadWithoutExtension(PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")
_info = staticmethod(partial(get_info_func(), backend="ffmpeg"))

def test_mp3(self):
"""MP3 file without extension can be loaded
Expand Down Expand Up @@ -405,7 +405,7 @@ def read(self, n):

@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")
_info = staticmethod(partial(get_info_func(), backend="ffmpeg"))

def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
Expand Down Expand Up @@ -557,7 +557,7 @@ def test_tarfile(self, ext, dtype):
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")
_info = staticmethod(partial(get_info_func(), backend="ffmpeg"))

def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
Expand Down Expand Up @@ -600,7 +600,7 @@ def test_requests(self, ext, dtype):
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestInfoNoSuchFile(PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")
_info = staticmethod(partial(get_info_func(), backend="ffmpeg"))

def test_info_fail(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


class LoadTestBase(TempDirMixin, PytorchTestCase):
_load = partial(get_load_func(), backend="ffmpeg")
_load = staticmethod(partial(get_load_func(), backend="ffmpeg"))

def assert_format(
self,
Expand Down Expand Up @@ -324,7 +324,7 @@ def test_amb(self, dtype, num_channels, normalize, sample_rate=8000):
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestLoadWithoutExtension(PytorchTestCase):
_load = partial(get_load_func(), backend="ffmpeg")
_load = staticmethod(partial(get_load_func(), backend="ffmpeg"))

def test_mp3(self):
"""MP3 file without extension can be loaded
Expand Down Expand Up @@ -364,7 +364,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
because `load` function is rigrously tested for file path inputs to match libsox's result,
"""

_load = partial(get_load_func(), backend="ffmpeg")
_load = staticmethod(partial(get_load_func(), backend="ffmpeg"))

@parameterized.expand(
[
Expand Down Expand Up @@ -541,7 +541,7 @@ def read(self, n):
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
_load = partial(get_load_func(), backend="ffmpeg")
_load = staticmethod(partial(get_load_func(), backend="ffmpeg"))

@parameterized.expand(
[
Expand Down Expand Up @@ -606,7 +606,7 @@ def test_frame(self, frame_offset, num_frames):
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestLoadNoSuchFile(PytorchTestCase):
_load = partial(get_load_func(), backend="ffmpeg")
_load = staticmethod(partial(get_load_func(), backend="ffmpeg"))

def test_load_fail(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt


class SaveTestBase(TempDirMixin, TorchaudioTestCase):
_save = partial(get_save_func(), backend="ffmpeg")
_save = staticmethod(partial(get_save_func(), backend="ffmpeg"))

def assert_save_consistency(
self,
Expand Down Expand Up @@ -398,7 +398,7 @@ def test_save_multi_channels(self, num_channels):
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `self._save`"""

_save = partial(get_save_func(), backend="ffmpeg")
_save = staticmethod(partial(get_save_func(), backend="ffmpeg"))

@parameterized.expand([(True,), (False,)], name_func=name_func)
def test_save_channels_first(self, channels_first):
Expand Down Expand Up @@ -444,7 +444,7 @@ def test_save_tensor_preserve(self, dtype):
@skipIfNoExec("sox")
@skipIfNoFFmpeg
class TestSaveNonExistingDirectory(PytorchTestCase):
_save = partial(get_save_func(), backend="ffmpeg")
_save = staticmethod(partial(get_save_func(), backend="ffmpeg"))

def test_save_fail(self):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="soundfile")
_info = staticmethod(partial(get_info_func(), backend="soundfile"))

@parameterize(
["float32", "int32", "int16", "uint8"],
Expand Down Expand Up @@ -127,7 +127,7 @@ class MockSoundFileInfo:

@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="soundfile")
_info = staticmethod(partial(get_info_func(), backend="soundfile"))

def _test_fileobj(self, ext, subtype, bits_per_sample):
"""Query audio via file-like object works"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __exit__(self, *args, **kwargs):


class MockedLoadTest(PytorchTestCase):
_load = partial(get_load_func(), backend="soundfile")
_load = staticmethod(partial(get_load_func(), backend="soundfile"))

def assert_dtype(self, ext, dtype, sample_rate, num_channels, normalize, channels_first):
"""When format is WAV or NIST, normalize=False will return the native dtype Tensor, otherwise float32"""
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_flac(self, sample_rate, num_channels, normalize, channels_first):


class LoadTestBase(TempDirMixin, PytorchTestCase):
_load = partial(get_load_func(), backend="soundfile")
_load = staticmethod(partial(get_load_func(), backend="soundfile"))

def assert_wav(
self,
Expand Down Expand Up @@ -272,7 +272,7 @@ def test_flac(self, dtype, sample_rate, num_channels, channels_first):
class TestLoadFormat(TempDirMixin, PytorchTestCase):
"""Given `format` parameter, `so.load` can load files without extension"""

_load = partial(get_load_func(), backend="soundfile")
_load = staticmethod(partial(get_load_func(), backend="soundfile"))
original = None
path = None

Expand Down Expand Up @@ -314,7 +314,7 @@ def test_flac(self, format_):

@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
_load = partial(get_load_func(), backend="soundfile")
_load = staticmethod(partial(get_load_func(), backend="soundfile"))

def _test_fileobj(self, ext):
"""Loading audio via file-like object works"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class MockedSaveTest(PytorchTestCase):
_save = partial(get_save_func(), backend="soundfile")
_save = staticmethod(partial(get_save_func(), backend="soundfile"))

@nested_params(
["float32", "int32", "int16", "uint8"],
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_ogg(self, dtype, sample_rate, num_channels, channels_first):

@skipIfNoModule("soundfile")
class SaveTestBase(TempDirMixin, PytorchTestCase):
_save = partial(get_save_func(), backend="soundfile")
_save = staticmethod(partial(get_save_func(), backend="soundfile"))

def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`self._save` can save wav format."""
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_ogg(self, sample_rate, num_channels):
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `self._save`"""

_save = partial(get_save_func(), backend="soundfile")
_save = staticmethod(partial(get_save_func(), backend="soundfile"))

@parameterize([True, False])
def test_channels_first(self, channels_first):
Expand All @@ -279,7 +279,7 @@ def test_channels_first(self, channels_first):

@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
_save = partial(get_save_func(), backend="soundfile")
_save = staticmethod(partial(get_save_func(), backend="soundfile"))

def _test_fileobj(self, ext):
"""Saving audio to file-like object works"""
Expand Down
10 changes: 5 additions & 5 deletions test/torchaudio_unittest/backend/dispatcher/sox/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
@skipIfNoExec("sox")
@skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
_info = staticmethod(partial(get_info_func(), backend="sox"))

@parameterized.expand(
list(
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_htk(self):
@disabledInCI
@skipIfNoSoxDecoder("opus")
class TestInfoOpus(PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
_info = staticmethod(partial(get_info_func(), backend="sox"))

@parameterized.expand(
list(
Expand Down Expand Up @@ -321,7 +321,7 @@ def read(self, n):
@skipIfNoSox
@skipIfNoExec("sox")
class TestFileObject(FileObjTestBase, PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
_info = staticmethod(partial(get_info_func(), backend="sox"))

def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
Expand Down Expand Up @@ -353,7 +353,7 @@ def test_fileobj(self, ext, dtype):
@skipIfNoExec("sox")
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
_info = staticmethod(partial(get_info_func(), backend="sox"))

def _query_http(self, ext, dtype, sample_rate, num_channels, num_frames):
audio_path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
Expand Down Expand Up @@ -387,7 +387,7 @@ def test_requests(self, ext, dtype):

@skipIfNoSox
class TestInfoNoSuchFile(PytorchTestCase):
_info = partial(get_info_func(), backend="sox")
_info = staticmethod(partial(get_info_func(), backend="sox"))

def test_info_fail(self):
"""
Expand Down
Loading