Skip to content

Commit 0ad941f

Browse files
committed
wave: add format to getparams
1 parent 1701bf5 commit 0ad941f

3 files changed

Lines changed: 47 additions & 5 deletions

File tree

Lib/test/audiotests.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,14 @@ def check_params(self, f, nchannels, sampwidth, framerate, nframes,
3838

3939
params = f.getparams()
4040
self.assertEqual(params,
41-
(nchannels, sampwidth, framerate, nframes, comptype, compname))
41+
(nchannels, sampwidth, framerate, nframes, comptype, compname, format))
4242
self.assertEqual(params.nchannels, nchannels)
4343
self.assertEqual(params.sampwidth, sampwidth)
4444
self.assertEqual(params.framerate, framerate)
4545
self.assertEqual(params.nframes, nframes)
4646
self.assertEqual(params.comptype, comptype)
4747
self.assertEqual(params.compname, compname)
48+
self.assertEqual(params.format, format)
4849

4950
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
5051
dump = pickle.dumps(params, proto)
@@ -304,6 +305,8 @@ def test_read(self):
304305
f.setpos(f.getnframes() + 1)
305306

306307
def test_copy(self):
308+
if self.readonly:
309+
self.skipTest('Read only file format')
307310
f = self.f = self.module.open(self.sndfilepath)
308311
fout = self.fout = self.module.open(TESTFN, 'wb')
309312
fout.setparams(f.getparams())

Lib/test/test_wave.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,38 @@ def test__all__(self):
173173

174174
class WaveLowLevelTest(unittest.TestCase):
175175

176+
def test_setparams_6_tuple_defaults_to_pcm(self):
177+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
178+
filename = fp.name
179+
self.addCleanup(unlink, filename)
180+
181+
with wave.open(filename, 'wb') as w:
182+
w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
183+
w.setparams((1, 2, 22050, 0, 'NONE', 'not compressed'))
184+
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_PCM)
185+
186+
def test_setparams_7_tuple_uses_format(self):
187+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
188+
filename = fp.name
189+
self.addCleanup(unlink, filename)
190+
191+
with wave.open(filename, 'wb') as w:
192+
w.setparams((1, 2, 22050, 0, 'NONE', 'not compressed',
193+
wave.WAVE_FORMAT_IEEE_FLOAT))
194+
self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
195+
196+
def test_getparams_has_format_field(self):
197+
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
198+
filename = fp.name
199+
self.addCleanup(unlink, filename)
200+
201+
with wave.open(filename, 'wb') as w:
202+
w.setparams((1, 2, 22050, 0, 'NONE', 'not compressed',
203+
wave.WAVE_FORMAT_IEEE_FLOAT))
204+
params = w.getparams()
205+
self.assertEqual(params.format, wave.WAVE_FORMAT_IEEE_FLOAT)
206+
self.assertEqual(params[:6], (1, 2, 22050, 0, 'NONE', 'not compressed'))
207+
176208
def test_getformat_setformat(self):
177209
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
178210
filename = fp.name

Lib/wave.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ class Error(Exception):
9393
_array_fmts = None, 'b', 'h', None, 'i'
9494

9595
_wave_params = namedtuple('_wave_params',
96-
'nchannels sampwidth framerate nframes comptype compname')
96+
'nchannels sampwidth framerate nframes comptype compname format',
97+
defaults=(WAVE_FORMAT_PCM,))
9798

9899

99100
def _byteswap(data, width):
@@ -349,7 +350,8 @@ def getcompname(self):
349350
def getparams(self):
350351
return _wave_params(self.getnchannels(), self.getsampwidth(),
351352
self.getframerate(), self.getnframes(),
352-
self.getcomptype(), self.getcompname())
353+
self.getcomptype(), self.getcompname(),
354+
self.getformat())
353355

354356
def setpos(self, pos):
355357
if pos < 0 or pos > self._nframes:
@@ -552,20 +554,25 @@ def getcompname(self):
552554
return self._compname
553555

554556
def setparams(self, params):
555-
nchannels, sampwidth, framerate, nframes, comptype, compname = params
556557
if self._datawritten:
557558
raise Error('cannot change parameters after starting to write')
559+
if len(params) == 6:
560+
nchannels, sampwidth, framerate, nframes, comptype, compname = params
561+
format = WAVE_FORMAT_PCM
562+
else:
563+
nchannels, sampwidth, framerate, nframes, comptype, compname, format = params
558564
self.setnchannels(nchannels)
559565
self.setsampwidth(sampwidth)
560566
self.setframerate(framerate)
561567
self.setnframes(nframes)
562568
self.setcomptype(comptype, compname)
569+
self.setformat(format)
563570

564571
def getparams(self):
565572
if not self._nchannels or not self._sampwidth or not self._framerate:
566573
raise Error('not all parameters set')
567574
return _wave_params(self._nchannels, self._sampwidth, self._framerate,
568-
self._nframes, self._comptype, self._compname)
575+
self._nframes, self._comptype, self._compname, self._format)
569576

570577
def tell(self):
571578
return self._nframeswritten

0 commit comments

Comments
 (0)