Reading the median filter implementation in dcase2024
from メディアンフィルタを実装してみよう
How is a median filter used for event prediction?
The explanation in FMSG-JLESS SUBMISSION FOR DCASE 2024 TASK4 ON SOUND EVENT DETECTION WITH HETEROGENEOUS TRAINING DATASET AND POTENTIALLY MISSING LABELS is thorough and easy to follow
The median filter is used to decide where each event starts and ends, and that result is output
Predictions are first made at a fine, per-frame level
The frame-level decisions are then consolidated by the filter (a toy sketch follows the list of drawbacks below)
The drawbacks are as follows
Lowering the threshold to raise sensitivity picks up noise and silence, so the detected event durations get stretched
Conversely, raising the threshold to improve precision misses quiet sounds, so the event durations get cut short
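As a toy sketch of the two steps above (my own example, not code from the baseline; it assumes scipy.ndimage.median_filter for the smoothing), per-frame decisions obtained by thresholding are consolidated by a median window, which removes isolated false positives while keeping contiguous events:
code: python
import numpy as np
from scipy.ndimage import median_filter

# Toy frame-level probabilities for a single class.
probs = np.array([0.1, 0.9, 0.1, 0.1, 0.8, 0.9, 0.85, 0.1, 0.1, 0.2])

# Step 1: per-frame decisions by thresholding (frame 1 is an isolated false positive).
frame_decisions = (probs > 0.5).astype(int)
print(frame_decisions)  # [0 1 0 0 1 1 1 0 0 0]

# Step 2: consolidate with a 3-frame median window; the isolated frame is
# removed and the contiguous event at frames 4-6 survives.
smoothed = median_filter(frame_decisions, size=3)
print(smoothed)  # [0 0 0 0 1 1 1 0 0 0]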
Let's go look at the implementation
code: python
import numpy as np
from scipy.ndimage import median_filter


class ClassWiseMedianFilter:
    def __init__(self, filter_lens=(1, 1, 1)):
        # One median window length per class.
        self.filter_lens = filter_lens

    def __call__(self, x, **kwargs):
        # x: (n_frames, n_classes) array of frame-level scores.
        out = []
        for indx_cls in range(x.shape[-1]):
            # Smooth each class track with its own window length.
            smoothed = median_filter(
                x[..., indx_cls][..., None], (self.filter_lens[indx_cls], 1)
            )[:, 0]
            out.append(smoothed)
        out = np.stack(out, -1)
        return out
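A minimal usage sketch (my own example; the (n_frames, n_classes) input shape is inferred from how the filter is applied in batched_decode_preds below):
code: python
import numpy as np

# Toy frame-level scores: 20 frames x 3 classes.
rng = np.random.default_rng(0)
scores = rng.random((20, 3))

# One median window length per class; longer windows smooth more aggressively.
smoother = ClassWiseMedianFilter(filter_lens=(7, 3, 1))
smoothed = smoother(scores)
print(smoothed.shape)  # (20, 3), same shape as the input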
Rather than the median filter itself, let's look at where it is used
code: python
batched_decode_preds(
    strong_preds_student[mask_strong],
    filenames_strong,
    self.encoder,
    median_filter=self.median_filter,
    thresholds=[],
)
The internal structure of strong_preds isn't clear, but presumably it holds frame-level predictions for each event class, which are then smoothed?
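Judging from the function body below, strong_preds looks like a tensor of shape (batch, n_classes, n_frames); a rough sketch of that reading (the shapes are my assumption, not confirmed from the repo):
code: python
import torch

# Assumed shape: 4 clips x 10 classes x 156 frames.
strong_preds = torch.rand(4, 10, 156)

c_scores = strong_preds[0]              # (n_classes, n_frames) for one clip
c_scores = c_scores.transpose(0, 1)     # (n_frames, n_classes), as in batched_decode_preds
c_scores = c_scores.detach().cpu().numpy()
print(c_scores.shape)                   # (156, 10) -> fed to the class-wise median filter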
code: python
import numpy as np
import pandas as pd
from pathlib import Path
# create_score_dataframe comes from the sed_scores_eval package used by the baseline.


def batched_decode_preds(
    strong_preds,
    filenames,
    encoder,
    thresholds=[0.5],
    median_filter=None,
    pad_indx=None,
):
    """Decode a batch of predictions to dataframes. Each threshold gives a different dataframe,
    stored in a dictionary.

    Args:
        strong_preds: torch.Tensor, batch of strong predictions.
        filenames: list, the list of filenames of the current batch.
        encoder: ManyHotEncoder object, object used to decode predictions.
        thresholds: list, the list of thresholds to be used for predictions.
        median_filter: int, the number of frames for which to apply median window (smoothing).
        pad_indx: list, the list of indexes which have been used for padding.
    Returns:
        dict of predictions, each key is a threshold and the value is the DataFrame of predictions.
    """
    # Init a dataframe per threshold
    scores_raw = {}
    scores_postprocessed = {}
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()

    for j in range(strong_preds.shape[0]):  # over batches
        audio_id = Path(filenames[j]).stem
        filename = audio_id + ".wav"
        c_scores = strong_preds[j]
        if pad_indx is not None:
            # Handle pad_indx given either as a list or as a tensor.
            if isinstance(pad_indx[j], list):
                true_len = int(c_scores.shape[-1] * pad_indx[j][0])
            else:
                true_len = int(c_scores.shape[-1] * pad_indx[j].item())
            c_scores = c_scores[:true_len]
        c_scores = c_scores.transpose(0, 1).detach().cpu().numpy()
        scores_raw[audio_id] = create_score_dataframe(
            scores=c_scores,
            timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
            event_classes=encoder.labels,
        )
        if median_filter is not None:
            c_scores = median_filter(c_scores)
        scores_postprocessed[audio_id] = create_score_dataframe(
            scores=c_scores,
            timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
            event_classes=encoder.labels,
        )
        for c_th in thresholds:
            pred = c_scores > c_th
            pred = encoder.decode_strong(pred)
            pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
            pred["filename"] = filename
            prediction_dfs[c_th] = pd.concat(
                [prediction_dfs[c_th], pred], ignore_index=True
            )

    return scores_raw, scores_postprocessed, prediction_dfs
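To make the last loop concrete: after smoothing, each threshold turns the (n_frames, n_classes) scores into a boolean mask, and encoder.decode_strong converts contiguous runs of active frames into (event_label, onset, offset) rows. A rough stand-in sketch of that idea (decode_strong_like and the fixed hop_sec are hypothetical, not the ManyHotEncoder API):
code: python
import numpy as np
import pandas as pd


def decode_strong_like(mask, labels, hop_sec=0.064):
    # Hypothetical stand-in for encoder.decode_strong: convert contiguous runs of
    # active frames into (event_label, onset, offset) rows, using a fixed frame hop.
    rows = []
    for cls_idx, label in enumerate(labels):
        active = np.flatnonzero(mask[:, cls_idx])
        if active.size == 0:
            continue
        # Split the active frame indices into contiguous runs.
        runs = np.split(active, np.where(np.diff(active) > 1)[0] + 1)
        for run in runs:
            rows.append([label, run[0] * hop_sec, (run[-1] + 1) * hop_sec])
    return rows


labels = ["Speech", "Dog", "Blender"]
c_scores = np.random.rand(100, 3)  # smoothed (n_frames, n_classes) scores
pred = pd.DataFrame(
    decode_strong_like(c_scores > 0.5, labels),
    columns=["event_label", "onset", "offset"],
)
pred["filename"] = "Y_example.wav"
print(pred.head())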