-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprocessor.py
267 lines (210 loc) · 8.19 KB
/
processor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import pretty_midi
# Sizes of each event family in the token vocabulary.
RANGE_NOTE_ON = 128     # one token per MIDI pitch
RANGE_NOTE_OFF = 128    # one token per MIDI pitch
RANGE_VEL = 32          # velocity quantized to 32 levels
RANGE_TIME_SHIFT = 100  # time shifts in 10 ms steps

# First token index of each family. The vocabulary is laid out as
# [note_on | note_off | time_shift | velocity], i.e. 0-127, 128-255,
# 256-355 and 356-387 with the sizes above.
START_IDX = dict(
    note_on=0,
    note_off=RANGE_NOTE_ON,
    time_shift=RANGE_NOTE_ON + RANGE_NOTE_OFF,
    velocity=RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT,
)
class SustainAdapter:
    """Plain record holding a sustain event's time (as ``start``) and ``type``.

    NOTE(review): nothing in the visible part of this file uses this class.
    """

    def __init__(self, time, type):
        self.type = type
        self.start = time
class SustainDownManager:
    """Collects the notes that start inside one sustain interval [start, end].

    After all notes are added, ``transposition_notes`` re-times them:
    each note is cut at the onset of the next managed note of the same
    pitch; a note with no later same-pitch note is held until at least
    ``end``.
    """

    def __init__(self, start, end):
        self.start, self.end = start, end
        self.managed_notes = []
        # pitch -> start time of the most recently visited (i.e. next
        # later) managed note of that pitch, filled by transposition_notes.
        self._note_dict = {}

    def add_managed_note(self, note: pretty_midi.Note):
        self.managed_notes.append(note)

    def transposition_notes(self):
        # Walk from the latest note back to the earliest so that, for each
        # note, the next onset of its pitch is already recorded.
        for note in self.managed_notes[::-1]:
            if note.pitch in self._note_dict:
                note.end = self._note_dict[note.pitch]
            else:
                note.end = max(self.end, note.end)
            self._note_dict[note.pitch] = note.start
# One edge of a note: its note_on or its note_off, as a standalone event.
class SplitNote:
    """A note edge at ``time`` for pitch ``value``.

    ``type`` is normally 'note_on' or 'note_off'; ``velocity`` is the note's
    velocity for note_on edges (``_divide_note`` passes None for note_off).
    """

    def __init__(self, type, time, value, velocity):
        self.type = type
        self.time = time
        self.value = value
        self.velocity = velocity

    def __repr__(self):
        return (f'<[SNote] time: {self.time} type: {self.type}, '
                f'value: {self.value}, velocity: {self.velocity}>')
class Event:
    """One token of the performance encoding.

    ``type`` is a key of ``START_IDX`` ('note_on', 'note_off', 'time_shift'
    or 'velocity') and ``value`` is the offset inside that family's range.
    """

    def __init__(self, event_type, value):
        self.type = event_type
        self.value = value

    def __repr__(self):
        return '<Event type: {}, value: {}>'.format(self.type, self.value)

    def to_int(self):
        """Return the flat vocabulary index of this event."""
        return START_IDX[self.type] + self.value

    @staticmethod
    def from_int(int_value):
        """Build an Event from a flat vocabulary index (inverse of to_int)."""
        info = Event._type_check(int_value)
        return Event(info['type'], info['value'])

    @staticmethod
    def _type_check(int_value):
        """Split a flat index into ``{'type': ..., 'value': ...}``.

        The index is matched against the note_on, note_off and time_shift
        ranges in that order. Anything else, including indices outside the
        vocabulary, is treated as a velocity token.

        ``int_value`` is never modified. Only plain subtraction is used,
        because ``-=`` on a mutable array scalar (e.g. a 0-d NumPy array or a
        tensor element view) would change the caller's sequence in place.
        """
        range_note_on = range(0, RANGE_NOTE_ON)
        range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON + RANGE_NOTE_OFF)
        range_time_shift = range(RANGE_NOTE_ON + RANGE_NOTE_OFF,
                                 RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
        if int_value in range_note_on:
            return {'type': 'note_on', 'value': int_value}
        if int_value in range_note_off:
            return {'type': 'note_off', 'value': int_value - RANGE_NOTE_ON}
        if int_value in range_time_shift:
            return {'type': 'time_shift',
                    'value': int_value - (RANGE_NOTE_ON + RANGE_NOTE_OFF)}
        return {'type': 'velocity',
                'value': int_value - (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)}
def _divide_note(notes):
    """Split every note into a note_on edge and a note_off edge.

    ``notes`` is sorted in place by start time first. The result is
    [on_0, off_0, on_1, off_1, ...] in that order; note_off edges carry
    velocity None.
    """
    notes.sort(key=lambda note: note.start)
    return [
        edge
        for note in notes
        for edge in (
            SplitNote('note_on', note.start, note.pitch, note.velocity),
            SplitNote('note_off', note.end, note.pitch, None),
        )
    ]
def _merge_note(snote_sequence):
    """Pair note_on/note_off SplitNotes back into ``pretty_midi.Note`` objects.

    A note_on opens a note for its pitch, replacing any onset already open
    for that pitch. The next note_off of that pitch closes it.
    - A note_off at the same time as its onset is skipped. The onset stays
      open, so a later note_off can still close it.
    - Once an onset has produced a note, it is consumed. A repeated note_off
      therefore cannot emit a second, overlapping note from the same onset.
    - A note_off with no open onset is reported with ``print`` and skipped.
    - SplitNotes of any other type are ignored.

    Returns the notes in the order their note_off edges appear.
    """
    note_on_dict = {}  # pitch -> currently open note_on SplitNote
    result_array = []
    for snote in snote_sequence:
        if snote.type == 'note_on':
            note_on_dict[snote.value] = snote
        elif snote.type == 'note_off':
            try:
                on = note_on_dict[snote.value]
            except KeyError:
                print('info removed pitch: {}'.format(snote.value))
                continue
            if snote.time - on.time == 0:
                continue
            del note_on_dict[snote.value]
            result_array.append(
                pretty_midi.Note(on.velocity, snote.value, on.time, snote.time))
    return result_array
def _snote2events(snote: SplitNote, prev_vel: int):
    """Translate one SplitNote into its Event(s).

    For edges that carry a velocity, the velocity is quantized with ``// 4``.
    A velocity Event is emitted first whenever the quantized value differs
    from ``prev_vel``. The note Event itself always comes last.
    """
    note_event = [Event(event_type=snote.type, value=snote.value)]
    if snote.velocity is None:
        return note_event
    quantized = snote.velocity // 4
    if quantized == prev_vel:
        return note_event
    return [Event(event_type='velocity', value=quantized)] + note_event
def _event_seq2snote_seq(event_sequence):
    """Replay decoded Events on a clock and emit the note edges as SplitNotes.

    - A time_shift event advances the clock by (value + 1) / 100 seconds.
    - A velocity event sets the velocity (value * 4) used by later notes.
    - Every other event (note_on / note_off) becomes a SplitNote at the
      current time, carrying the current velocity.

    The three cases are mutually exclusive. time_shift events must not also
    produce a SplitNote.
    """
    timeline = 0
    velocity = 0
    snote_seq = []
    for event in event_sequence:
        if event.type == 'time_shift':
            timeline += ((event.value + 1) / 100)
        elif event.type == 'velocity':
            # NOTE(review): a quantized velocity of 0 decodes to MIDI
            # velocity 0, which many players treat as note-off — confirm.
            velocity = event.value * 4
        else:
            snote_seq.append(SplitNote(event.type, timeline, event.value, velocity))
    return snote_seq
def _make_time_sift_events(prev_time, post_time):
    """Encode the gap from ``prev_time`` to ``post_time`` as time_shift Events.

    Times are quantized to 10 ms steps. A time_shift Event of value v stands
    for v + 1 steps, so one Event covers at most RANGE_TIME_SHIFT steps.
    Longer gaps are split into several Events.

    Both endpoints are rounded on the absolute 10 ms grid before subtracting.
    When calls are chained (each prev_time is the previous post_time), the
    emitted shifts therefore telescope to the exact quantized position of
    every event. Rounding each difference independently would instead
    accumulate up to 5 ms of error per call.

    A gap that quantizes to zero or less yields no Events. A negative gap
    must not emit a negative value, because its token would alias into the
    note_off range.
    """
    time_interval = int(round(post_time * 100)) - int(round(prev_time * 100))
    results = []
    while time_interval >= RANGE_TIME_SHIFT:
        results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT - 1))
        time_interval -= RANGE_TIME_SHIFT
    if time_interval > 0:
        results.append(Event(event_type='time_shift', value=time_interval - 1))
    return results
def _control_preprocess(ctrl_changes):
sustains = []
manager = None
for ctrl in ctrl_changes:
if ctrl.value >= 64 and manager is None:
# sustain down
manager = SustainDownManager(start=ctrl.time, end=None)
elif ctrl.value < 64 and manager is not None:
# sustain up
manager.end = ctrl.time
sustains.append(manager)
manager = None
elif ctrl.value < 64 and len(sustains) > 0:
sustains[-1].end = ctrl.time
return sustains
def _note_preprocess(susteins, notes):
note_stream = []
if susteins: # if the midi file has sustain controls
for sustain in susteins:
for note_idx, note in enumerate(notes):
if note.start < sustain.start:
note_stream.append(note)
elif note.start > sustain.end:
notes = notes[note_idx:]
sustain.transposition_notes()
break
else:
sustain.add_managed_note(note)
for sustain in susteins:
note_stream += sustain.managed_notes
else: # else, just push everything into note stream
for note_idx, note in enumerate(notes):
note_stream.append(note)
note_stream.sort(key= lambda x: x.start)
return note_stream
def encode_midi(file_path):
    """Encode a MIDI file as a flat list of integer event tokens.

    The notes of all instruments, drum tracks included, are merged into one
    stream. Sustain-pedal handling (controller 64) is applied per
    instrument. The stream is then serialized as time_shift, velocity,
    note_on and note_off tokens (see START_IDX for the vocabulary layout).

    Args:
        file_path: anything ``pretty_midi.PrettyMIDI(midi_file=...)`` accepts.

    Returns:
        A list of ints.
    """
    events = []
    notes = []
    mid = pretty_midi.PrettyMIDI(midi_file=file_path)

    for inst in mid.instruments:
        # Controller 64 is the sustain (damper) pedal; for other controller
        # numbers see
        # https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst.notes)

    dnotes = _divide_note(notes)
    dnotes.sort(key=lambda x: x.time)

    cur_time = 0
    # Last *quantized* velocity (velocity // 4) written to the stream. This is
    # the scale _snote2events compares prev_vel against. It starts at 0, the
    # same initial velocity the decoder assumes, and it is only updated by
    # edges that carry a velocity (note_off edges carry None). Storing the raw
    # velocity here would emit a redundant velocity token before almost every
    # note_on.
    cur_vel = 0
    for snote in dnotes:
        events += _make_time_sift_events(prev_time=cur_time, post_time=snote.time)
        events += _snote2events(snote=snote, prev_vel=cur_vel)
        cur_time = snote.time
        if snote.velocity is not None:
            cur_vel = snote.velocity // 4
    return [e.to_int() for e in events]
def decode_midi(idx_array, file_path=None):
    """Rebuild a PrettyMIDI object from integer event tokens.

    All decoded notes go on a single instrument (program 0, not a drum
    track), sorted by start time. If ``file_path`` is given, the MIDI is also
    written there. Returns the PrettyMIDI object.
    """
    snote_seq = _event_seq2snote_seq(list(map(Event.from_int, idx_array)))
    note_seq = sorted(_merge_note(snote_seq), key=lambda note: note.start)

    # To use another instrument, pick a General MIDI program number from
    # https://www.midi.org/specifications/item/gm-level-1-sound-set
    instrument = pretty_midi.Instrument(
        program=0,
        is_drum=False,
        name="Composed by Super Piano Music Transformer AI",
    )
    instrument.notes = note_seq

    midi = pretty_midi.PrettyMIDI()
    midi.instruments.append(instrument)
    if file_path is not None:
        midi.write(file_path)
    return midi
if __name__ == '__main__':
    # Round-trip demo: encode a sample file, decode it to bin/test.mid,
    # then dump the original file's contents for comparison.
    source_path = 'bin/ADIG04.mid'
    tokens = encode_midi(source_path)
    print(tokens)
    decode_midi(tokens, file_path='bin/test.mid')

    reference = pretty_midi.PrettyMIDI(source_path)
    print(reference)
    print(reference.instruments[0])
    for instrument in reference.instruments:
        print(instrument.control_changes)
        print(instrument.notes)