Skip to content

Commit 9dc2c48

Browse files
author
Simon Holliday
committed
- Refactored the MidiEvent dataclass to handle its own conversions to and from mido messages, eliminating repetitive if/elif blocks in sequencer.py
- Added `_has_pitch_at_beat()` helper to PatternBuilder and refactored the Euclidean, Bresenham, and Ghost Fill algorithms to use it, removing duplicated logic. - Deleted the private _msg_to_midi_event helper in the Sequencer class since it was superseded by the cleaner factory method on MidiEvent
1 parent 137628c commit 9dc2c48

4 files changed

Lines changed: 109 additions & 110 deletions

File tree

subsequence/pattern_algorithmic.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def note (
4545
duration: float,
4646
) -> "PatternAlgorithmicMixin": ...
4747
def _resolve_pitch (self, pitch: typing.Union[int, str]) -> int: ...
48+
def _has_pitch_at_beat (self, pitch: typing.Union[int, str], beat: float) -> bool: ...
4849

4950
def _place_rhythm_sequence (
5051
self,
@@ -64,7 +65,6 @@ def _place_rhythm_sequence (
6465
across the pattern length. Zeros and dropout-gated steps are skipped.
6566
"""
6667

67-
midi_pitch = self._resolve_pitch(pitch)
6868
step_duration = self._pattern.length / len(sequence)
6969

7070
for i, hit_value in enumerate(sequence):
@@ -75,11 +75,8 @@ def _place_rhythm_sequence (
7575
if dropout > 0 and rng.random() < dropout:
7676
continue
7777

78-
if no_overlap:
79-
pulse = int(i * step_duration * subsequence.constants.MIDI_QUARTER_NOTE + 0.5)
80-
if pulse in self._pattern.steps:
81-
if any(n.pitch == midi_pitch for n in self._pattern.steps[pulse].notes):
82-
continue
78+
if no_overlap and self._has_pitch_at_beat(pitch, i * step_duration):
79+
continue
8380

8481
self.note(pitch=pitch, beat=i * step_duration, velocity=velocity, duration=duration)
8582

@@ -265,12 +262,8 @@ def bresenham_poly (
265262

266263
pitch = voice_names[voice_idx]
267264

268-
if no_overlap:
269-
midi_pitch = self._resolve_pitch(pitch)
270-
pulse = int(step_idx * step_duration * subsequence.constants.MIDI_QUARTER_NOTE + 0.5)
271-
if pulse in self._pattern.steps:
272-
if any(n.pitch == midi_pitch for n in self._pattern.steps[pulse].notes):
273-
continue
265+
if no_overlap and self._has_pitch_at_beat(pitch, step_idx * step_duration):
266+
continue
274267

275268
if isinstance(velocity, dict):
276269
vel = velocity.get(pitch, subsequence.constants.velocity.DEFAULT_VELOCITY)
@@ -483,7 +476,6 @@ def ghost_fill (
483476
if max_weight <= 0:
484477
return self
485478

486-
midi_pitch = self._resolve_pitch(pitch)
487479
step_duration = self._pattern.length / grid
488480

489481
for i in range(grid):
@@ -492,11 +484,8 @@ def ghost_fill (
492484
if rng.random() >= prob:
493485
continue
494486

495-
if no_overlap:
496-
pulse = int(round(i * step_duration * subsequence.constants.MIDI_QUARTER_NOTE))
497-
if pulse in self._pattern.steps:
498-
if any(n.pitch == midi_pitch for n in self._pattern.steps[pulse].notes):
499-
continue
487+
if no_overlap and self._has_pitch_at_beat(pitch, i * step_duration):
488+
continue
500489

501490
if callable(velocity):
502491
vel = int(velocity(i))

subsequence/pattern_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,15 @@ def grid (self) -> int:
149149
"""Number of grid slots in this pattern (e.g. 16 for a 4-beat sixteenth-note pattern)."""
150150
return self._default_grid
151151

152+
def _has_pitch_at_beat (self, pitch: typing.Union[int, str], beat: float) -> bool:
153+
"""Helper to check if a pitch is already sounding at a specific beat."""
154+
midi_pitch = self._resolve_pitch(pitch)
155+
pulse = int(beat * subsequence.constants.MIDI_QUARTER_NOTE + 0.5)
156+
if pulse in self._pattern.steps:
157+
return any(n.pitch == midi_pitch for n in self._pattern.steps[pulse].notes)
158+
return False
159+
160+
152161
@property
153162
def c (self) -> typing.Optional[subsequence.conductor.Conductor]:
154163

subsequence/sequencer.py

Lines changed: 87 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,86 @@ class MidiEvent:
6060
device: int = dataclasses.field(compare=False, default=0)
6161

6262

63+
def to_mido (self) -> typing.Optional[typing.Union[mido.Message, mido.MetaMessage]]:
64+
65+
"""Convert this event to a mido.Message, or None if it's an internal type (like OSC)."""
66+
67+
if self.message_type in ('note_on', 'note_off'):
68+
return mido.Message(
69+
self.message_type,
70+
channel = self.channel,
71+
note = self.note,
72+
velocity = self.velocity
73+
)
74+
75+
if self.message_type == 'control_change':
76+
return mido.Message(
77+
'control_change',
78+
channel = self.channel,
79+
control = self.control,
80+
value = self.value
81+
)
82+
83+
if self.message_type == 'pitchwheel':
84+
return mido.Message(
85+
'pitchwheel',
86+
channel = self.channel,
87+
pitch = self.value
88+
)
89+
90+
if self.message_type == 'program_change':
91+
return mido.Message(
92+
'program_change',
93+
channel = self.channel,
94+
program = self.value
95+
)
96+
97+
if self.message_type == 'sysex':
98+
return mido.Message(
99+
'sysex',
100+
data = self.data if self.data is not None else b''
101+
)
102+
103+
return None
104+
105+
106+
@classmethod
107+
def from_mido (cls, pulse: int, msg: typing.Union[mido.Message, mido.MetaMessage], device: int = 0) -> "MidiEvent":
108+
109+
"""Convert a mido.Message to a MidiEvent."""
110+
111+
if msg.type == 'pitchwheel':
112+
return cls(
113+
pulse = pulse,
114+
message_type = 'pitchwheel',
115+
channel = msg.channel,
116+
value = msg.pitch,
117+
device = device,
118+
)
119+
120+
if msg.type == 'control_change':
121+
return cls(
122+
pulse = pulse,
123+
message_type = 'control_change',
124+
channel = msg.channel,
125+
control = msg.control,
126+
value = msg.value,
127+
device = device,
128+
)
129+
130+
return cls(
131+
pulse = pulse,
132+
message_type = msg.type,
133+
channel = getattr(msg, 'channel', 0),
134+
value = getattr(msg, 'value', 0),
135+
note = getattr(msg, 'note', 0),
136+
velocity = getattr(msg, 'velocity', 0),
137+
data = getattr(msg, 'data', None),
138+
control = getattr(msg, 'control', 0),
139+
device = device,
140+
)
141+
142+
63143
@dataclasses.dataclass
64144
class ScheduledPattern:
65145

@@ -571,36 +651,6 @@ def _on_midi_input (self, message: typing.Any, device_idx: int = 0) -> None:
571651
self._forward_buffer.append((self.pulse_count, out_msg))
572652

573653

574-
def _msg_to_midi_event (self, pulse: int, msg: typing.Any) -> MidiEvent:
575-
576-
"""Convert a mido.Message to a MidiEvent at the given pulse.
577-
578-
Used to inject queued CC forwards into the event heap.
579-
"""
580-
581-
if msg.type == 'pitchwheel':
582-
return MidiEvent(
583-
pulse = pulse,
584-
message_type = 'pitchwheel',
585-
channel = msg.channel,
586-
value = msg.pitch,
587-
)
588-
elif msg.type == 'control_change':
589-
return MidiEvent(
590-
pulse = pulse,
591-
message_type = 'control_change',
592-
channel = msg.channel,
593-
control = msg.control,
594-
value = msg.value,
595-
)
596-
else:
597-
return MidiEvent(
598-
pulse = pulse,
599-
message_type = msg.type,
600-
channel = getattr(msg, 'channel', 0),
601-
value = getattr(msg, 'value', 0),
602-
)
603-
604654

605655
def _estimate_bpm (self, tick_time: float) -> None:
606656

@@ -1280,7 +1330,7 @@ async def _process_pulse (self, pulse: int) -> None:
12801330
# while the callback thread calls append().
12811331
while self._forward_buffer:
12821332
fwd_pulse, fwd_msg = self._forward_buffer.popleft()
1283-
heapq.heappush(self.event_queue, self._msg_to_midi_event(fwd_pulse, fwd_msg))
1333+
heapq.heappush(self.event_queue, MidiEvent.from_mido(fwd_pulse, fwd_msg))
12841334

12851335
while self.event_queue and self.event_queue[0].pulse <= pulse:
12861336

@@ -1298,21 +1348,9 @@ async def _process_pulse (self, pulse: int) -> None:
12981348

12991349
if self.recording and event.message_type != 'osc':
13001350

1301-
if event.message_type in ('note_on', 'note_off'):
1302-
self._record_event(event.pulse, mido.Message(event.message_type, channel=event.channel, note=event.note, velocity=event.velocity))
1303-
1304-
elif event.message_type == 'control_change':
1305-
self._record_event(event.pulse, mido.Message('control_change', channel=event.channel, control=event.control, value=event.value))
1306-
1307-
elif event.message_type == 'pitchwheel':
1308-
self._record_event(event.pulse, mido.Message('pitchwheel', channel=event.channel, pitch=event.value))
1309-
1310-
elif event.message_type == 'program_change':
1311-
self._record_event(event.pulse, mido.Message('program_change', channel=event.channel, program=event.value))
1312-
1313-
elif event.message_type == 'sysex':
1314-
raw = event.data if event.data is not None else b''
1315-
self._record_event(event.pulse, mido.Message('sysex', data=raw))
1351+
mido_msg = event.to_mido()
1352+
if mido_msg is not None:
1353+
self._record_event(event.pulse, mido_msg)
13161354

13171355

13181356
async def _stop_all_active_notes (self) -> None:
@@ -1343,49 +1381,14 @@ def _send_midi (self, event: MidiEvent) -> None:
13431381

13441382
try:
13451383

1346-
if event.message_type in ('note_on', 'note_off'):
1347-
msg = mido.Message(
1348-
event.message_type,
1349-
channel = event.channel,
1350-
note = event.note,
1351-
velocity = event.velocity
1352-
)
1353-
1354-
elif event.message_type == 'control_change':
1355-
msg = mido.Message(
1356-
'control_change',
1357-
channel = event.channel,
1358-
control = event.control,
1359-
value = event.value
1360-
)
1361-
1362-
elif event.message_type == 'pitchwheel':
1363-
msg = mido.Message(
1364-
'pitchwheel',
1365-
channel = event.channel,
1366-
pitch = event.value
1367-
)
1368-
1369-
elif event.message_type == 'program_change':
1370-
msg = mido.Message(
1371-
'program_change',
1372-
channel = event.channel,
1373-
program = event.value
1374-
)
1375-
1376-
elif event.message_type == 'sysex':
1377-
msg = mido.Message(
1378-
'sysex',
1379-
data = event.data if event.data is not None else b''
1380-
)
1381-
1382-
elif event.message_type == 'osc':
1384+
if event.message_type == 'osc':
13831385
if self.osc_server is not None:
13841386
address, args = event.data
13851387
self.osc_server.send(address, *args)
13861388
return
13871389

1388-
else:
1390+
msg = event.to_mido()
1391+
if msg is None:
13891392
return
13901393

13911394
port.send(msg)

tests/test_cc_forward.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -287,23 +287,21 @@ async def test_queued_drained_in_process_pulse(patch_midi: None) -> None:
287287
assert any(m.control == 74 and m.value == 64 for m in spy.sent)
288288

289289

290-
def test_msg_to_midi_event_cc(patch_midi: None) -> None:
291-
"""_msg_to_midi_event should correctly convert a CC message."""
292-
seq = _make_sequencer(conftest.SpyMidiOut())
290+
def test_midi_event_from_mido_cc(patch_midi: None) -> None:
291+
"""from_mido should correctly convert a CC message."""
293292
msg = mido.Message('control_change', channel=1, control=74, value=100)
294-
event = seq._msg_to_midi_event(10, msg)
293+
event = subsequence.sequencer.MidiEvent.from_mido(10, msg)
295294
assert event.pulse == 10
296295
assert event.message_type == 'control_change'
297296
assert event.channel == 1
298297
assert event.control == 74
299298
assert event.value == 100
300299

301300

302-
def test_msg_to_midi_event_pitchwheel(patch_midi: None) -> None:
303-
"""_msg_to_midi_event should correctly convert a pitchwheel message."""
304-
seq = _make_sequencer(conftest.SpyMidiOut())
301+
def test_midi_event_from_mido_pitchwheel(patch_midi: None) -> None:
302+
"""from_mido should correctly convert a pitchwheel message."""
305303
msg = mido.Message('pitchwheel', channel=0, pitch=-4096)
306-
event = seq._msg_to_midi_event(5, msg)
304+
event = subsequence.sequencer.MidiEvent.from_mido(5, msg)
307305
assert event.pulse == 5
308306
assert event.message_type == 'pitchwheel'
309307
assert event.value == -4096

0 commit comments

Comments
 (0)