Source code for torch_ecg.databases.cpsc_databases.cpsc2021

# -*- coding: utf-8 -*-

import json
import math
import os
import time
import warnings
from numbers import Real
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import scipy.io as sio
import wfdb

from ...cfg import CFG, DEFAULTS
from ...utils.misc import add_docstring, get_record_list_recursive3, ms2samples
from ...utils.utils_interval import generalized_intervals_intersection
from ..base import DEFAULT_FIG_SIZE_PER_SEC, DataBaseInfo, PhysioNetDataBase, WFDB_Beat_Annotations, _PlotCfg

__all__ = [
    "CPSC2021",
    "compute_metrics",
]


_CPSC2021_INFO = DataBaseInfo(
    title="""
    The 4th China Physiological Signal Challenge 2021:
    Paroxysmal Atrial Fibrillation Events Detection from Dynamic ECG Recordings
    """,
    about=r"""
    1. source ECG data are recorded from 12-lead Holter or 3-lead wearable ECG monitoring devices
    2. dataset provides variable-length ECG fragments extracted from lead I and lead II of the long-term source ECG data, each sampled at 200 Hz
    3. an AF event is required to last no fewer than 5 heart beats
    4. training set in the 1st stage consists of 730 records, extracted from the Holter records from 12 AF patients and 42 non-AF patients (usually including other abnormal and normal rhythms); training set in the 2nd stage consists of 706 records from 37 AF patients (18 PAF patients) and 14 non-AF patients
    5. the test set comprises data from the same sources as the training set as well as from a DIFFERENT data source, and is **NOT** to be released at any point
    6. annotations are standardized according to PhysioBank Annotations (Ref. [2]_ or :meth:`PhysioNetDataBase.helper`), and include the beat annotations (R peak location and beat type), the rhythm annotations (rhythm change flag and rhythm type) and the diagnosis of the global rhythm
    7. classification of a record is stored in corresponding .hea file, which can be accessed via the attribute `comments` of a wfdb Record obtained using :func:`wfdb.rdheader`, :func:`wfdb.rdrecord`, and :func:`wfdb.rdsamp`; beat annotations and rhythm annotations can be accessed using the attributes `symbol`, `aux_note` of a ``wfdb`` Annotation obtained using :func:`wfdb.rdann`, corresponding indices in the signal can be accessed via the attribute `sample`
    8. challenge task:

        - classification of rhythm types: non-AF rhythm (N), persistent AF rhythm (AFf), and paroxysmal AF rhythm (AFp)
        - locating the onset and offset of any predicted AF episode

    9. challenge metrics:

        - metrics (Ur, scoring matrix) for classification:

          .. tikz:: The scoring matrix for the recording-level classification result.
            :align: center
            :libs: positioning

            \tikzstyle{rect} = [rectangle, text width = 50, text centered, inner sep = 3pt, minimum height = 50]
            \tikzstyle{txt} = [rectangle, text centered, inner sep = 3pt, minimum height = 1.5]

            \node[rect, fill = green!25] at (0,0) (31) {$-0.5$};
            \node[rect, fill = green!10, right = 0 of 31] (32) {$0$};
            \node[rect, fill = red!30, right = 0 of 32] (33) {$+1$};
            \node[rect, fill = green!40, above = 0 of 31] (21) {$-1$};
            \node[rect, fill = red!30, above = 0 of 32] (22) {$+1$};
            \node[rect, fill = green!10, above = 0 of 33] (23) {$0$};
            \node[rect, fill = red!30, above = 0 of 21] (11) {$+1$};
            \node[rect, fill = green!60, above = 0 of 22] (12) {$-2$};
            \node[rect, fill = green!40, above = 0 of 23] (13) {$-1$};

            \node[txt, below = 0 of 31] {N};
            \node[txt, below = 0 of 32] (anchor_h) {AF$_{\text{f}}$};
            \node[txt, below = 0 of 33] {AF$_{\text{p}}$};
            \node[txt, left = 0 of 31] {AF$_{\text{p}}$};
            \node[txt, left = 0 of 21] (anchor_v) {AF$_{\text{f}}$};
            \node[txt, left = 0 of 11] {N};

            \node[txt, below = 0 of anchor_h] {\large\textbf{Annotation (Label)}};
            \node[txt, left = 0.6 of anchor_v, rotate = 90, anchor = north] {\large\textbf{Prediction}};

        - metric (Ue) for detecting onsets and offsets for AF events (episodes): +1 if the detected onset (or offset) is within ±1 beat of the annotated position, and +0.5 if within ±2 beats.
        - final score (U):

          .. math::

            U = \dfrac{1}{N} \sum\limits_{i=1}^N \left( Ur_i + \dfrac{Ma_i}{\max\{Mr_i, Ma_i\}} \cdot Ue_i \right)

          where :math:`N` is the number of records,
          :math:`Ur_i` is the classification score of the :math:`i`-th record,
          :math:`Ue_i` is the sum of the onset and offset scores of the predicted AF episodes of the :math:`i`-th record,
          :math:`Ma_i` is the number of annotated AF episodes,
          and :math:`Mr_i` is the number of predicted AF episodes.

    10. Challenge official website [1]_. Webpage of the database on PhysioNet [2]_.
    """,
    note="""
    1. if an ECG record is classified as AFf, the provided onset and offset locations should be the first and last record points. If an ECG record is classified as N, the answer should be an empty list
    2. it can be inferred from the classification scoring matrix that the penalty for false negatives of AFf is very heavy, while mixing up AFf and AFp is not penalized
    3. the atrial fibrillation and atrial flutter flags ("AFIB" and "AFL") in the annotations are treated as the same type when scoring
    4. the 3 classes can coexist in ONE subject (not one record). For example, subject 61 has 6 records with label "N", 1 with label "AFp", and 2 with label "AFf"
    5. rhythm change annotations ("(AFIB", "(AFL", "(N" in the `aux_note` field, or "+" in the `symbol` field of the annotation files) are inserted 0.15s ahead of (for onsets) or behind (for offsets) the corresponding R peaks.
    6. some records were revised when an AF episode, or the pause between adjacent AF episodes, contained fewer than 5 heart beats. The id numbers of the revised records are summarized in the attached `REVISED_RECORDS`.
    """,
    usage=[
        "AF (event, fine) detection",
    ],
    references=[
        "http://icbeb2021.pastconf.com/CPSC2021",
        "https://www.physionet.org/content/cpsc2021/",
    ],
    doi="10.13026/ksya-qw89",
)
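
# Illustrative sketch (not part of the original module) of the low-level access
# patterns described in item 7 of the database info above, using ``wfdb`` directly.
# The local path ``/path/to/cpsc2021`` and the record ``data_98_5`` are assumptions
# for illustration only:
#
#     >>> import wfdb
#     >>> header = wfdb.rdheader("/path/to/cpsc2021/data_98_5")
#     >>> header.comments[0]   # classification of the record, e.g. "paroxysmal atrial fibrillation"
#     >>> ann = wfdb.rdann("/path/to/cpsc2021/data_98_5", extension="atr")
#     >>> ann.sample           # sample indices of the beat / rhythm annotations
#     >>> ann.symbol           # beat annotations (R peak types, "+" for rhythm changes)
#     >>> ann.aux_note         # rhythm annotations ("(AFIB", "(AFL", "(N", ...)
#
# The :class:`CPSC2021` reader defined below wraps these calls and adds convenience
# methods such as :meth:`load_rpeaks`, :meth:`load_af_episodes`, and :meth:`load_label`.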


[docs] @add_docstring(_CPSC2021_INFO.format_database_docstring(), mode="prepend") class CPSC2021(PhysioNetDataBase): """ Parameters ---------- db_dir : `path-like`, optional Storage path of the database. If not specified, data will be fetched from Physionet. working_dir : `path-like`, optional Working directory, to store intermediate files and log files. verbose : int, default 1 Level of logging verbosity. kwargs : dict, optional Auxilliary key word arguments """ __name__ = "CPSC2021" def __init__( self, db_dir: Optional[Union[str, bytes, os.PathLike]] = None, working_dir: Optional[Union[str, bytes, os.PathLike]] = None, verbose: int = 1, **kwargs: Any, ) -> None: super().__init__( db_name="cpsc2021", db_dir=db_dir, working_dir=working_dir, verbose=verbose, **kwargs, ) self.db_dir_base = Path(self.db_dir) self.db_tranches = [ "training_I", "training_II", ] self.fs = 200 self.spacing = 1000 / self.fs self.rec_ext = "dat" self.ann_ext = "atr" self.header_ext = "hea" self.all_leads = ["I", "II"] self.rec_patterns_with_ext = f"^data_(?:\\d+)_(?:\\d+)\\.{self.rec_ext}$" self._labels_f2a = { # fullname to abbreviation "non atrial fibrillation": "N", "paroxysmal atrial fibrillation": "AFp", "persistent atrial fibrillation": "AFf", } self._labels_f2n = { # fullname to number "non atrial fibrillation": 0, "paroxysmal atrial fibrillation": 2, "persistent atrial fibrillation": 1, } self.nb_records = CFG({"training_I": 730, "training_II": 706}) self._all_records = CFG({t: [] for t in self.db_tranches}) self.__all_records = None self.__revised_records = [] self._all_subjects = CFG({t: [] for t in self.db_tranches}) self.__all_subjects = None self._subject_records = CFG({t: [] for t in self.db_tranches}) self._stats = pd.DataFrame() self._stats_columns = [ "record", "tranche", "subject_id", "record_id", "label", "fs", "sig_len", "sig_len_sec", "revised", ] self._df_records = pd.DataFrame() self._ls_rec() self._aggregate_stats() self._diagnoses_records_list = None self._ls_diagnoses_records() self._epsilon = 1e-7 # dealing with round(0.5) = 0, hence keeping accordance with output length of `resample_poly` # self.palette = {"spb": "yellow", "pvc": "red",} @property def all_records(self) -> List[str]: if self.__all_records is None: self._ls_rec() return self.__all_records def _ls_rec(self) -> None: """Find all records in the database directory and store them (path, metadata, etc.) in some private attributes. 
""" self._df_records = pd.DataFrame() self._df_records["path"] = get_record_list_recursive3(self.db_dir_base, self.rec_patterns_with_ext, relative=False) self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x)) self._df_records["record"] = self._df_records["path"].apply(lambda x: x.stem) self._df_records["subject_id"] = self._df_records["record"].apply(lambda rec: int(rec.split("_")[1])) self._df_records["record_id"] = self._df_records["record"].apply(lambda rec: int(rec.split("_")[2])) self._df_records["tranche"] = self._df_records["subject_id"].apply(lambda x: "training_I" if x <= 53 else "training_II") if self._subsample is not None: size = min( len(self._df_records), max(1, int(round(self._subsample * len(self._df_records)))), ) self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) self._df_records.set_index("record", inplace=True) self._all_records = CFG({t: [] for t in self.db_tranches}) self._all_subjects = CFG({t: [] for t in self.db_tranches}) self._subject_records = CFG({t: [] for t in self.db_tranches}) for t in self.db_tranches: self._all_records[t] = sorted(self._df_records[self._df_records["tranche"] == t].index.tolist()) self._all_subjects[t] = sorted( list(set([self.get_subject_id(rec) for rec in self._all_records[t]])), key=lambda s: int(s), ) self._subject_records[t] = CFG( {sid: [rec for rec in self._all_records[t] if self.get_subject_id(rec) == sid] for sid in self._all_subjects[t]} ) self._all_records_inv = {r: t for t, l_r in self._all_records.items() for r in l_r} self._all_subjects_inv = {s: t for t, l_s in self._all_subjects.items() for s in l_s} self.__all_records = sorted(self._df_records.index.tolist()) self.__all_subjects = sorted(self._df_records["subject_id"].apply(str).unique().tolist()) def _aggregate_stats(self) -> None: """Aggregate stats on the whole dataset.""" stats_file = "stats.csv" stats_file_fp = self.db_dir_base / stats_file if stats_file_fp.is_file() and self._subsample is None: self._stats = pd.read_csv(stats_file_fp) if self._stats.empty or set(self._stats_columns) != set(self._stats.columns): self.logger.info("Please wait patiently to let the reader aggregate statistics on the whole dataset...") start = time.time() self._stats = pd.DataFrame(self.all_records, columns=["record"]) # use self.all_records to ensure it's computed self._stats["tranche"] = self._stats["record"].apply(lambda s: self._all_records_inv[s]) self._stats["subject_id"] = self._stats["record"].apply(lambda s: int(s.split("_")[1])) self._stats["record_id"] = self._stats["record"].apply(lambda s: int(s.split("_")[2])) self._stats["label"] = self._stats["record"].apply(lambda s: self.load_label(s)) self._stats["fs"] = self.fs self._stats["sig_len"] = self._stats["record"].apply( lambda s: wfdb.rdheader(str(self.get_absolute_path(s))).sig_len ) self._stats["sig_len_sec"] = self._stats["sig_len"] / self._stats["fs"] self._stats["revised"] = self._stats["record"].apply(lambda s: 1 if s in self.__revised_records else 0) self._stats = self._stats.sort_values(by=["subject_id", "record_id"], ignore_index=True) self._stats = self._stats[self._stats_columns] if self._subsample is None: self._stats.to_csv(stats_file_fp, index=False) self.logger.info(f"Done in {time.time() - start:.5f} seconds!") else: pass # currently no need to parse the loaded csv file self._stats["subject_id"] = self._stats["subject_id"].apply(lambda s: str(s)) @property def 
all_subjects(self) -> List[str]: return self.__all_subjects @property def subject_records(self) -> CFG: return self._subject_records @property def df_stats(self) -> pd.DataFrame: return self._stats def _ls_diagnoses_records(self) -> None: """List all the records for all diagnoses.""" fn = "diagnoses_records_list.json" dr_fp = self.db_dir_base / fn if dr_fp.is_file() and self._subsample is None: self._diagnoses_records_list = json.loads(dr_fp.read_text()) else: start = time.time() if self.df_stats.empty: self.logger.info("Please wait several minutes patiently to let the reader list records for each diagnosis...") self._diagnoses_records_list = {d: [] for d in self._labels_f2a.values()} for rec in self.all_records: lb = self.load_label(rec) self._diagnoses_records_list[lb].append(rec) self.logger.info(f"Done in {time.time() - start:.5f} seconds!") else: self._diagnoses_records_list = { d: self.df_stats[self.df_stats["label"] == d]["record"].tolist() for d in self._labels_f2a.values() } if self._subsample is None: dr_fp.write_text(json.dumps(self._diagnoses_records_list, ensure_ascii=False)) self._diagnoses_records_list = CFG(self._diagnoses_records_list) @property def diagnoses_records_list(self): if self._diagnoses_records_list is None: self._ls_diagnoses_records() return self._diagnoses_records_list
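
    # Usage sketch (illustrative; assumes ``dr = CPSC2021(db_dir)`` has been created
    # with a local copy of the database at ``db_dir``):
    #
    #     >>> dr.df_stats.head()                       # per-record metadata (label, fs, sig_len, ...)
    #     >>> dr.diagnoses_records_list["AFp"]         # names of records with label "AFp"
    #     >>> dr.subject_records["training_II"]["98"]  # records of subject 98 (2nd training tranche)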
    def get_subject_id(self, rec: Union[str, int]) -> str:
        """Attach a unique subject ID to the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        sid : str
            Subject ID corresponding to the record.

        """
        if isinstance(rec, int):
            rec = self[rec]
        sid = rec.split("_")[1]
        return sid
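
    # Usage sketch (illustrative; ``dr`` is a ``CPSC2021`` instance as above):
    #
    #     >>> dr.get_subject_id("data_98_5")
    #     '98'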
    def get_absolute_path(self, rec: Union[str, int], extension: Optional[str] = None) -> Path:
        """Get the absolute path of the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        extension : str, optional
            Extension of the file.

        Returns
        -------
        abs_path : pathlib.Path
            Absolute path of the file.

        """
        if isinstance(rec, int):
            rec = self[rec]
        abs_path = self._df_records.loc[rec, "path"]
        if extension is not None:
            if not extension.startswith("."):
                extension = f".{extension}"
            abs_path = abs_path.with_suffix(extension)
        return abs_path
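
    # Usage sketch (illustrative):
    #
    #     >>> dr.get_absolute_path("data_98_5", extension="hea")   # <db_dir>/.../data_98_5.hea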
def _validate_samp_interval( self, rec: Union[str, int], sampfrom: Optional[int] = None, sampto: Optional[int] = None, ) -> Tuple[int, int]: """Validate `sampfrom` and `sampto` so that they are reasonable. Parameters ---------- rec : str or int Record name or index of the record in :attr:`all_records`. sampfrom : int, optional Start index of the data to be loaded. sampto : int, optional End index of the data to be loaded. Returns ------- sf : int Index sampling from. st : int Index sampling to. """ if isinstance(rec, int): rec = self[rec] sig_len = self.df_stats[self.df_stats.record == rec].iloc[0].sig_len sf, st = ( sampfrom if sampfrom is not None else 0, min(sampto, sig_len) if sampto is not None else sig_len, ) if sampto is not None and sampto > sig_len: warnings.warn( f"the end index {sampto} is larger than the signal length {sig_len}, " f"so it is set to {sig_len}", RuntimeWarning, ) if sf >= st: raise ValueError("Invalid `sampfrom` and `sampto`") return sf, st
[docs] def load_ann( self, rec: Union[str, int], field: Optional[Literal["rpeaks", "af_episodes", "label", "raw", "wfdb"]] = None, sampfrom: Optional[int] = None, sampto: Optional[int] = None, **kwargs: Any, ) -> Union[dict, np.ndarray, List[List[int]], str]: """Load annotations of the record. Parameters ---------- rec : str or int Record name or index of the record in :attr:`all_records`. field : {"rpeaks", "af_episodes", "label", "raw", "wfdb"}, optional Field of the annotation. If is None, all fields of the annotation will be returned in the form of a dict. If is "raw" or "wfdb", then the corresponding wfdb "Annotation" will be returned. sampfrom : int, optional Start index of the annotation to be loaded. sampto: int, optional End index of the annotation to be loaded. kwargs : dict Key word arguments for functions loading rpeaks, af_episodes, and label respectively, including: - fs: int, optional, the resampling frequency - fmt: str, format of af_episodes, or format of label, for more details, ref. corresponding functions. Used only when `field` is specified (not None). Returns ------- ann : dict or list or numpy.ndarray or str Annotaton of the record. """ sf, st = self._validate_samp_interval(rec, sampfrom, sampto) ann = wfdb.rdann( str(self.get_absolute_path(rec)), extension=self.ann_ext, sampfrom=sf, sampto=st, ) # `load_af_episodes` should not use sampfrom, sampto func = { "rpeaks": self.load_rpeaks, "af_episodes": self.load_af_episodes, "label": self.load_label, } if field is None: ann = {k: f(rec, ann, sf, st) for k, f in func.items()} if kwargs: warnings.warn( f"key word arguments `{list(kwargs.keys())}` ignored when `field` is not specified!", RuntimeWarning, ) return ann elif field.lower() in ["raw", "wfdb"]: return ann try: f = func[field.lower()] except Exception: raise ValueError(f"Invalid `field`: {field}") ann = f(rec, ann, sf, st, **kwargs) return ann
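
    # Usage sketch (illustrative; ``dr`` is a ``CPSC2021`` instance):
    #
    #     >>> ann = dr.load_ann("data_98_5")                            # dict with keys "rpeaks", "af_episodes", "label"
    #     >>> raw = dr.load_ann("data_98_5", field="wfdb")              # the underlying wfdb.Annotation
    #     >>> mask = dr.load_ann("data_98_5", field="af_episodes", fmt="mask")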
[docs] def load_rpeaks( self, rec: Union[str, int], ann: Optional[wfdb.Annotation] = None, sampfrom: Optional[int] = None, sampto: Optional[int] = None, keep_original: bool = False, valid_only: bool = True, fs: Optional[Real] = None, ) -> np.ndarray: """Load position (in terms of samples) of rpeaks. Parameters ---------- rec : str or int Record name or index of the record in :attr:`all_records`. ann : wfdb.Annotation, optional The wfdb Annotation of the record. If is None, corresponding annotation file will be read. sampfrom : int, optional Start index of the rpeak positions to be loaded. sampto : int, optional End index of the rpeak positions to be loaded. keep_original : bool, default False If True, indices will keep the same with the annotation file, otherwise subtract `sampfrom` if specified. valid_only : bool, default True If True, only valid rpeaks will be returned, otherwise, all indices in the `sample` field of the annotation will be returned. Valid rpeaks are those with symbol in `WFDB_Beat_Annotations`. Symbols in `WFDB_Non_Beat_Annotations` are considered as invalid rpeaks fs : numbers.Real, optional If not None, positions of the loaded rpeaks will be ajusted according to this sampling frequency. Returns ------- rpeaks : numpy.ndarray Position (in terms of samples) of rpeaks of the record. """ if ann is None: sf, st = self._validate_samp_interval(rec, sampfrom, sampto) ann = wfdb.rdann( str(self.get_absolute_path(rec)), extension=self.ann_ext, sampfrom=sf, sampto=st, ) critical_points = ann.sample symbols = ann.symbol if sampfrom and not keep_original: critical_points = critical_points - sampfrom if fs is not None and fs != self.fs: critical_points = np.round(critical_points * fs / self.fs + self._epsilon).astype(int) if valid_only: rpeaks_valid = np.isin(symbols, list(WFDB_Beat_Annotations.keys())) rpeaks = critical_points[rpeaks_valid] else: rpeaks = critical_points return rpeaks
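
    # Usage sketch (illustrative):
    #
    #     >>> rpeaks = dr.load_rpeaks("data_98_5", sampfrom=2000, sampto=6000)   # indices relative to `sampfrom`
    #     >>> rpeaks_400 = dr.load_rpeaks("data_98_5", fs=400)                   # positions adjusted to 400 Hz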
[docs] @add_docstring(load_rpeaks.__doc__) def load_rpeak_indices( self, rec: Union[str, int], ann: Optional[wfdb.Annotation] = None, sampfrom: Optional[int] = None, sampto: Optional[int] = None, keep_original: bool = False, valid_only: bool = True, fs: Optional[Real] = None, ) -> np.ndarray: """alias of `self.load_rpeaks`""" return self.load_rpeaks( rec=rec, ann=ann, sampfrom=sampfrom, sampto=sampto, keep_original=keep_original, valid_only=valid_only, fs=fs, )
[docs] def load_af_episodes( self, rec: Union[str, int], ann: Optional[wfdb.Annotation] = None, sampfrom: Optional[int] = None, sampto: Optional[int] = None, keep_original: bool = False, fs: Optional[Real] = None, fmt: Literal["intervals", "mask", "c_intervals"] = "intervals", ) -> Union[List[List[int]], np.ndarray]: """Load the episodes of atrial fibrillation, in terms of intervals or mask. Parameters ---------- rec : str or int Record name or index of the record in :attr:`all_records`. ann : wfdb.Annotation, optional The wfdb Annotation of the record. If is None, corresponding annotation file will be read. sampfrom : int, optional Start index of the AF episodes to be loaded. Not used when `fmt` is "c_intervals". sampto : int, optional End index of the AF episodes to be loaded. Not used when `fmt` is "c_intervals". keep_original : bool, default False If True, indices will keep the same with the annotation file, otherwise subtract `sampfrom` if specified. Valid only when `fmt` is not "c_intervals". fs : numbers.Real, optional If not None, positions of the loaded intervals or mask will be ajusted according to this sampling frequency. Otherwise, the sampling frequency of the record will be used. fmt : {"intervals", "mask", "c_intervals"}, optional Format of the episodes of atrial fibrillation, by default "intervals". Returns ------- af_episodes : list or numpy.ndarray Episodes of atrial fibrillation, in terms of intervals or mask. """ header = wfdb.rdheader(str(self.get_absolute_path(rec))) label = self._labels_f2a[header.comments[0]] siglen = header.sig_len _ann = wfdb.rdann(str(self.get_absolute_path(rec)), extension=self.ann_ext) sf, st = self._validate_samp_interval(rec, sampfrom, sampto) aux_note = np.array(_ann.aux_note) critical_points = _ann.sample af_start_inds = np.where((aux_note == "(AFIB") | (aux_note == "(AFL"))[0] # ref. NOTE 3. af_end_inds = np.where(aux_note == "(N")[0] assert len(af_start_inds) == len(af_end_inds), "unequal number of af period start indices and af period end indices" if fmt.lower() in [ "c_intervals", ]: if sf > 0 or st < siglen: raise ValueError("when `fmt` is `c_intervals`, `sampfrom` and `sampto` should never be used!") af_episodes = [[start, end] for start, end in zip(af_start_inds, af_end_inds)] return af_episodes intervals = [] for start, end in zip(af_start_inds, af_end_inds): itv = [critical_points[start], critical_points[end]] intervals.append(itv) intervals = generalized_intervals_intersection(intervals, [[sf, st]]) siglen = st - sf if fs is not None and fs != self.fs: siglen = self._round(siglen * fs / self.fs) sf = self._round(sf * fs / self.fs) if label == "AFf": # ref. NOTE. 1 of the class docstring # the `ann.sample` does not always satify this point after resampling intervals = [[sf, siglen - 1]] else: intervals = [ [ self._round(itv[0] * fs / self.fs), self._round(itv[1] * fs / self.fs), ] for itv in intervals ] if not keep_original: intervals = [[itv[0] - sf, itv[1] - sf] for itv in intervals] sf = 0 af_episodes = intervals if fmt.lower() in [ "mask", ]: mask = np.zeros((siglen,), dtype=int) for itv in intervals: mask[itv[0] - sf : itv[1] - sf] = 1 af_episodes = mask return af_episodes
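
    # Usage sketch (illustrative):
    #
    #     >>> intervals = dr.load_af_episodes("data_98_5")                  # list of [start, end] sample indices
    #     >>> mask = dr.load_af_episodes("data_98_5", fmt="mask")           # binary mask over the whole record
    #     >>> c_itv = dr.load_af_episodes("data_98_5", fmt="c_intervals")   # indices into the annotation's `sample` array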
    def load_label(
        self,
        rec: Union[str, int],
        ann: Optional[wfdb.Annotation] = None,
        sampfrom: Optional[int] = None,
        sampto: Optional[int] = None,
        fmt: str = "a",
    ) -> str:
        """Load (classifying) label of the record.

        The three classes are:

            - "non atrial fibrillation",
            - "paroxysmal atrial fibrillation",
            - "persistent atrial fibrillation".

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        ann : wfdb.Annotation, optional
            Not used, to keep in accordance with other methods.
        sampfrom : int, optional
            Not used, to keep in accordance with other methods.
        sampto : int, optional
            Not used, to keep in accordance with other methods.
        fmt : str, default "a"
            Format of the label, case-insensitive, can be one of

            - "f", "fullname": the full name of the label
            - "a", "abbr", "abbreviation": abbreviation for the label
            - "n", "num", "number": class number of the label
              (in accordance with the settings of the official class map)

        Returns
        -------
        label : str
            Classifying label of the record.

        """
        header = wfdb.rdheader(str(self.get_absolute_path(rec)))
        label = header.comments[0]
        if fmt.lower() in ["a", "abbr", "abbreviation"]:
            label = self._labels_f2a[label]
        elif fmt.lower() in ["n", "num", "number"]:
            label = self._labels_f2n[label]
        elif fmt.lower() not in ["f", "fullname"]:
            raise ValueError(f"format `{fmt}` of labels is not supported!")
        return label
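
    # Usage sketch (illustrative):
    #
    #     >>> dr.load_label("data_98_5", fmt="f")   # full name, e.g. "non atrial fibrillation"
    #     >>> dr.load_label("data_98_5", fmt="a")   # abbreviation: one of "N", "AFp", "AFf"
    #     >>> dr.load_label("data_98_5", fmt="n")   # class number: 0, 1, or 2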
[docs] def gen_endpoint_score_mask( self, rec: Union[str, int], bias: dict = {1: 1, 2: 0.5}, verbose: Optional[int] = None, ) -> Tuple[np.ndarray, np.ndarray]: """Generate the scoring mask for the onsets and offsets of af episodes. Parameters ---------- rec : str or int Record name or index of the record in :attr:`all_records`. bias : dict, default {1: 1, 2: 0.5} Bias for the scoring of the onsets and offsets of af episodes. Keys are bias (with ±) in terms of number of rpeaks, and values are corresponding scores. verbose : int, optional Verbosity level. If is None, :attr:`self.verbose` will be used. Returns ------- onset_score_mask, offset_score_mask: Tuple[numpy.ndarray] 2-tuple of :class:`~numpy.ndarray`, which are the scoring mask for the onset and offsets predictions of af episodes. NOTE ---- The onsets in `af_intervals` are 0.15s ahead of the corresponding R peaks, while the offsets in `af_intervals` are 0.15s behind the corresponding R peaks. """ if isinstance(rec, int): rec = self[rec] masks = gen_endpoint_score_mask( siglen=self.df_stats[self.df_stats.record == rec].iloc[0].sig_len, critical_points=wfdb.rdann(str(self.get_absolute_path(rec)), extension=self.ann_ext).sample, af_intervals=self.load_af_episodes(rec, fmt="c_intervals"), bias=bias, verbose=verbose or self.verbose, ) return masks
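
    # Usage sketch (illustrative):
    #
    #     >>> onset_mask, offset_mask = dr.gen_endpoint_score_mask("data_98_5")
    #     >>> # both masks have the same length as the record; values are 1 within
    #     >>> # ±1 beat and 0.5 within ±2 beats of each annotated AF onset / offset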
[docs] def plot( self, rec: Union[str, int], data: Optional[np.ndarray] = None, ann: Optional[Dict[str, np.ndarray]] = None, ticks_granularity: int = 0, sampfrom: Optional[int] = None, sampto: Optional[int] = None, leads: Optional[Union[str, int, List[Union[str, int]]]] = None, waves: Optional[Dict[str, Sequence[int]]] = None, **kwargs, ) -> None: """Plot the signals of a record. plot the signals of a record or external signals (units in μV), with metadata (labels, episodes of atrial fibrillation, etc.), possibly also along with wave delineations. Parameters ---------- rec : str or int Record name or index of the record in :attr:`all_records`. data : numpy.ndarray, optional (2-lead) ECG signal to plot. Should be of the format "channel_first", and compatible with `leads`. If given, data of `rec` will not be used. This is useful when plotting filtered data. ann : dict, optional Annotations for `data`. Ignored if `data` is None. ticks_granularity : int, default 0 Granularity to plot axis ticks, the higher the more ticks. 0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks) sampfrom : int, optional Start index of the data to plot. sampto : int, optional End index of the data to plot. leads : str or List[str], optional Names of the leads to plot. waves : dict, optional Indices of the wave critical points, including "p_onsets", "p_peaks", "p_offsets", "q_onsets", "q_peaks", "r_peaks", "s_peaks", "s_offsets", "t_onsets", "t_peaks", "t_offsets" kwargs : dict, optional Additional keyword arguments to pass to :func:`matplotlib.pyplot.plot`. TODO ---- 1. Slice too long records, and plot separately for each segment. 2. Plot waves using :func:`~matplotlib.pyplot.axvspan`. NOTE ---- 1. `Locator` of ``plt`` has default `MAXTICKS` of 1000. If not modifying this number, at most 40 seconds of signal could be plotted once. 2. Raw data usually have very severe baseline drifts, hence the isoelectric line is not plotted. 
Contributors: Jeethan, and WEN Hao """ if isinstance(rec, int): rec = self[rec] if "plt" not in dir(): import matplotlib.pyplot as plt plt.MultipleLocator.MAXTICKS = 3000 _leads = self._normalize_leads(leads) if data is None: _data = self.load_data( rec, leads=_leads, data_format="channel_first", units="μV", sampfrom=sampfrom, sampto=sampto, ) else: units = self._auto_infer_units(data) self.logger.info(f"input data is auto detected to have units in {units}") if units.lower() == "mv": _data = 1000 * data else: _data = data assert _data.shape[0] == len( _leads ), f"number of leads from data of shape ({_data.shape[0]}) does not match the length ({len(_leads)}) of `leads`" sf, st = (sampfrom or 0), (sampto or len(_data)) if waves: if waves.get("p_onsets", None) and waves.get("p_offsets", None): p_waves = [[onset, offset] for onset, offset in zip(waves["p_onsets"], waves["p_offsets"])] elif waves.get("p_peaks", None): p_waves = [ [ max(0, p + ms2samples(_PlotCfg.p_onset, fs=self.fs)), min( _data.shape[1], p + ms2samples(_PlotCfg.p_offset, fs=self.fs), ), ] for p in waves["p_peaks"] ] else: p_waves = [] if waves.get("q_onsets", None) and waves.get("s_offsets", None): qrs = [[onset, offset] for onset, offset in zip(waves["q_onsets"], waves["s_offsets"])] elif waves.get("q_peaks", None) and waves.get("s_peaks", None): qrs = [ [ max(0, q + ms2samples(_PlotCfg.q_onset, fs=self.fs)), min( _data.shape[1], s + ms2samples(_PlotCfg.s_offset, fs=self.fs), ), ] for q, s in zip(waves["q_peaks"], waves["s_peaks"]) ] elif waves.get("r_peaks", None): qrs = [ [ max(0, r + ms2samples(_PlotCfg.qrs_radius, fs=self.fs)), min( _data.shape[1], r + ms2samples(_PlotCfg.qrs_radius, fs=self.fs), ), ] for r in waves["r_peaks"] ] else: qrs = [] if waves.get("t_onsets", None) and waves.get("t_offsets", None): t_waves = [[onset, offset] for onset, offset in zip(waves["t_onsets"], waves["t_offsets"])] elif waves.get("t_peaks", None): t_waves = [ [ max(0, t + ms2samples(_PlotCfg.t_onset, fs=self.fs)), min( _data.shape[1], t + ms2samples(_PlotCfg.t_offset, fs=self.fs), ), ] for t in waves["t_peaks"] ] else: t_waves = [] else: p_waves, qrs, t_waves = [], [], [] palette = { "p_waves": "cyan", "qrs": "green", "t_waves": "yellow", } plot_alpha = 0.4 if ann is None or data is None: _ann = self.load_ann(rec, sampfrom=sampfrom, sampto=sampto) rpeaks = _ann["rpeaks"] af_episodes = _ann["af_episodes"] af_episodes = [[itv[0] - sf, itv[1] - sf] for itv in af_episodes] label = _ann["label"] else: rpeaks = ann.get("rpeaks", []) af_episodes = ann.get("af_episodes", []) label = ann.get("label", "") nb_leads = len(_leads) line_len = self.fs * 25 # 25 seconds nb_lines = math.ceil(_data.shape[1] / line_len) bias_thr = 0.07 # winL = 0.06 # winR = 0.08 for idx in range(nb_lines): seg = _data[..., idx * line_len : (idx + 1) * line_len] secs = (sf + np.arange(seg.shape[1]) + idx * line_len) / self.fs fig_sz_w = int(round(DEFAULT_FIG_SIZE_PER_SEC * seg.shape[1] / self.fs)) # if same_range: # y_ranges = np.ones((seg.shape[0],)) * np.max(np.abs(seg)) + 100 # else: # y_ranges = np.max(np.abs(seg), axis=1) + 100 # fig_sz_h = 6 * y_ranges / 1500 fig_sz_h = 6 * sum([seg_lead.max() - seg_lead.min() + 200 for seg_lead in seg]) / 1500 fig, axes = plt.subplots(nb_leads, 1, sharex=True, figsize=(fig_sz_w, np.sum(fig_sz_h))) if nb_leads == 1: axes = [axes] for ax_idx in range(nb_leads): axes[ax_idx].plot(secs, seg[ax_idx], color="black", label=f"lead - {_leads[ax_idx]}") # axes[ax_idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red") # NOTE that 
`Locator` has default `MAXTICKS` equal to 1000 if ticks_granularity >= 1: axes[ax_idx].xaxis.set_major_locator(plt.MultipleLocator(0.2)) axes[ax_idx].yaxis.set_major_locator(plt.MultipleLocator(500)) axes[ax_idx].grid(which="major", linestyle="-", linewidth="0.5", color="red") if ticks_granularity >= 2: axes[ax_idx].xaxis.set_minor_locator(plt.MultipleLocator(0.04)) axes[ax_idx].yaxis.set_minor_locator(plt.MultipleLocator(100)) axes[ax_idx].grid(which="minor", linestyle=":", linewidth="0.5", color="black") # add extra info. to legend # https://stackoverflow.com/questions/16826711/is-it-possible-to-add-a-string-as-a-legend-item-in-matplotlib if label: axes[ax_idx].plot([], [], " ", label=f"label - {label}") seg_rpeaks = [r / self.fs for r in rpeaks if idx * line_len <= r < (idx + 1) * line_len] for r in seg_rpeaks: axes[ax_idx].axvspan( max(secs[0], r - bias_thr), min(secs[-1], r + bias_thr), color=palette["qrs"], alpha=0.3, ) seg_af_episodes = generalized_intervals_intersection( af_episodes, [[idx * line_len, (idx + 1) * line_len]], ) seg_af_episodes = [[itv[0] - idx * line_len, itv[1] - idx * line_len] for itv in seg_af_episodes] for itv_start, itv_end in seg_af_episodes: axes[ax_idx].plot( secs[itv_start:itv_end], seg[ax_idx, itv_start:itv_end], color="red", ) for w in ["p_waves", "qrs", "t_waves"]: for itv in eval(w): itv_start = max(0, itv[0] - idx * line_len) itv_end = min(itv[1] - idx * line_len, line_len) if not 0 <= itv_start < itv_end <= line_len: continue axes[ax_idx].axvspan( secs[itv[0] - idx * line_len], secs[itv[1] - idx * line_len], color=palette[w], alpha=plot_alpha, ) axes[ax_idx].legend(loc="upper left") axes[ax_idx].set_xlim(secs[0], secs[-1]) # axes[ax_idx].set_ylim(-y_ranges[ax_idx], y_ranges[ax_idx]) axes[ax_idx].set_xlabel("Time [s]") axes[ax_idx].set_ylabel("Voltage [μV]") plt.subplots_adjust(hspace=0.2) plt.show()
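
    # Usage sketch (illustrative; requires matplotlib):
    #
    #     >>> dr.plot("data_98_5", sampfrom=0, sampto=5000, ticks_granularity=2)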
    def _round(self, n: Real) -> int:
        """
        dealing with round(0.5) = 0,
        hence keeping accordance with output length of `resample_poly`
        """
        return int(round(n + self._epsilon))

    @property
    def url_(self) -> str:
        """URL of the compressed database file.

        .. versionadded:: 0.0.5

        """
        if self._url_compressed is not None:
            return self._url_compressed
        # currently, cpsc2021 is not included in the list obtained
        # using `wfdb.get_dbs()`
        self._url_compressed = (
            "https://www.physionet.org/static/published-projects/cpsc2021/"
            "paroxysmal-atrial-fibrillation-events-detection-from-dynamic-ECG-recordings"
            f"-the-4th-china-physiological-signal-challenge-2021-{self.version}.zip"
        )
        return self._url_compressed

    @property
    def database_info(self) -> DataBaseInfo:
        return _CPSC2021_INFO
################################################################### # copied and modified from the official scoring code ################################################################### R = np.array([[1, -1, -0.5], [-2, 1, 0], [-1, 0, 1]]) # scoring matrix for classification class RefInfo: def __init__(self, sample_path): self.sample_path = sample_path ( self.fs, self.len_sig, self.beat_loc, self.af_starts, self.af_ends, self.class_true, ) = self._load_ref() self.endpoints_true = np.dstack((self.af_starts, self.af_ends))[0, :, :] # self.endpoints_true = np.concatenate((self.af_starts, self.af_ends), axis=-1) if self.class_true == 1 or self.class_true == 2: ( self.onset_score_range, self.offset_score_range, ) = self._gen_endpoint_score_range() else: self.onset_score_range, self.offset_score_range = None, None def _load_ref(self): sig, fields = wfdb.rdsamp(self.sample_path) ann_ref = wfdb.rdann(self.sample_path, "atr") fs = fields["fs"] length = len(sig) sample_descrip = fields["comments"] beat_loc = np.array(ann_ref.sample) # r-peak locations ann_note = np.array(ann_ref.aux_note) # rhythm change flag af_start_scripts = np.where((ann_note == "(AFIB") | (ann_note == "(AFL"))[0] af_end_scripts = np.where(ann_note == "(N")[0] if "non atrial fibrillation" in sample_descrip: class_true = 0 elif "persistent atrial fibrillation" in sample_descrip: class_true = 1 elif "paroxysmal atrial fibrillation" in sample_descrip: class_true = 2 else: print("Error: the recording is out of range!") return -1 return fs, length, beat_loc, af_start_scripts, af_end_scripts, class_true def _gen_endpoint_score_range(self, verbose=0): """ """ onset_range = np.zeros((self.len_sig,), dtype=float) offset_range = np.zeros((self.len_sig,), dtype=float) for i, af_start in enumerate(self.af_starts): if self.class_true == 2: if max(af_start - 1, 0) == 0: onset_range[: self.beat_loc[af_start + 2]] += 1 if verbose > 0: print(f"official --- onset (c_ind, score 1): 0 --- {af_start+2}") print(f"official --- onset (sample, score 1): 0 --- {self.beat_loc[af_start+2]}") elif max(af_start - 2, 0) == 0: onset_range[self.beat_loc[af_start - 1] : self.beat_loc[af_start + 2]] += 1 if verbose > 0: print(f"official --- onset (c_ind, score 1): {af_start-1} --- {af_start+2}") print( f"official --- onset (sample, score 1): {self.beat_loc[af_start-1]} --- {self.beat_loc[af_start+2]}" ) onset_range[: self.beat_loc[af_start - 1]] += 0.5 if verbose > 0: print(f"official --- onset (c_ind, score 0.5): 0 --- {af_start-1}") print(f"official --- onset (sample, score 0.5): 0 --- {self.beat_loc[af_start-1]}") else: onset_range[self.beat_loc[af_start - 1] : self.beat_loc[af_start + 2]] += 1 if verbose > 0: print(f"official --- onset (c_ind, score 1): {af_start-1} --- {af_start+2}") print( f"official --- onset (sample, score 1): {self.beat_loc[af_start-1]} --- {self.beat_loc[af_start+2]}" ) onset_range[self.beat_loc[af_start - 2] : self.beat_loc[af_start - 1]] += 0.5 if verbose > 0: print(f"official --- onset (c_ind, score 0.5): {af_start-2} --- {af_start-1}") print( f"official --- onset (sample, score 0.5): {self.beat_loc[af_start-2]} --- {self.beat_loc[af_start-1]}" ) onset_range[self.beat_loc[af_start + 2] : self.beat_loc[af_start + 3]] += 0.5 if verbose > 0: print(f"official --- onset (c_ind, score 0.5): {af_start+2} --- {af_start+3}") print( f"official --- onset (sample, score 0.5): {self.beat_loc[af_start+2]} --- {self.beat_loc[af_start+3]}" ) elif self.class_true == 1: onset_range[: self.beat_loc[af_start + 2]] += 1 if verbose > 0: 
print(f"official --- onset (c_ind, score 1): 0 --- {af_start+2}") print(f"official --- onset (sample, score 1): 0 --- {self.beat_loc[af_start+2]}") onset_range[self.beat_loc[af_start + 2] : self.beat_loc[af_start + 3]] += 0.5 if verbose > 0: print(f"official --- onset (c_ind, score 0.5): {af_start+2} --- {af_start+3}") print( f"official --- onset (sample, score 0.5): {self.beat_loc[af_start+2]} --- {self.beat_loc[af_start+3]}" ) for i, af_end in enumerate(self.af_ends): if self.class_true == 2: if min(af_end + 1, len(self.beat_loc) - 1) == len(self.beat_loc) - 1: offset_range[self.beat_loc[af_end - 2] :] += 1 if verbose > 0: print(f"official --- offset (c_ind, score 1): {af_end-2} --- -1") print(f"official --- offset (sample, score 1): {self.beat_loc[af_end-2]} --- -1") elif min(af_end + 2, len(self.beat_loc) - 1) == len(self.beat_loc) - 1: offset_range[self.beat_loc[af_end - 2] : self.beat_loc[af_end + 1]] += 1 if verbose > 0: print(f"official --- offset (c_ind, score 1): {af_end-2} --- {af_end+1}") print(f"official --- offset (sample, score 1): {self.beat_loc[af_end-2]} --- {self.beat_loc[af_end+1]}") offset_range[self.beat_loc[af_end + 1] :] += 0.5 if verbose > 0: print(f"official --- offset (c_ind, score 0.5): {af_end+1} --- -1") print(f"official --- offset (sample, score 0.5): {self.beat_loc[af_end+1]} --- -1") else: offset_range[self.beat_loc[af_end - 2] : self.beat_loc[af_end + 1]] += 1 if verbose > 0: print(f"official --- offset (c_ind, score 1): {af_end-2} --- {af_end+1}") print(f"official --- offset (sample, score 1): {self.beat_loc[af_end-2]} --- {self.beat_loc[af_end+1]}") offset_range[self.beat_loc[af_end + 1] : min(self.beat_loc[af_end + 2], self.len_sig - 1)] += 0.5 if verbose > 0: print(f"official --- offset (c_ind, score 0.5): {af_end+1} --- -1") print( f"official --- offset (sample, score 0.5): {self.beat_loc[af_end+1]} --- {min(self.beat_loc[af_end+2], self.len_sig-1)}" ) offset_range[self.beat_loc[af_end - 3] : self.beat_loc[af_end - 2]] += 0.5 if verbose > 0: print(f"official --- offset (c_ind, score 0.5): {af_end-3} --- {af_end-2}") print(f"official --- offset (sample, score 0.5): {self.beat_loc[af_end-3]} --- {self.beat_loc[af_end-2]}") elif self.class_true == 1: offset_range[self.beat_loc[af_end - 2] :] += 1 if verbose > 0: print(f"official --- offset (c_ind, score 1): {af_end-2} --- -1") print(f"official --- offset (sample, score 1): {self.beat_loc[af_end-2]} --- -1") offset_range[self.beat_loc[af_end - 3] : self.beat_loc[af_end - 2]] += 0.5 if verbose > 0: print(f"official --- offset (c_ind, score 0.5): {af_end-3} --- {af_end-2}") print(f"official --- offset (sample, score 0.5): {self.beat_loc[af_end-3]} --- {self.beat_loc[af_end-2]}") return onset_range, offset_range def load_ans(ans_file): endpoints_pred = [] if ans_file.endswith(".json"): json_file = open(ans_file, "r") ans_dic = json.load(json_file) endpoints_pred = np.array(ans_dic["predict_endpoints"]) elif ans_file.endswith(".mat"): ans_struct = sio.loadmat(ans_file) endpoints_pred = ans_struct["predict_endpoints"] - 1 return endpoints_pred def ue_calculate(endpoints_pred, endpoints_true, onset_score_range, offset_score_range): score = 0 ma = len(endpoints_true) mr = len(endpoints_pred) if mr == 0: score = 0 else: for [start, end] in endpoints_pred: score += onset_score_range[int(start)] score += offset_score_range[int(end)] score *= ma / max(ma, mr) return score def ur_calculate(class_true, class_pred): score = R[int(class_true), int(class_pred)] return score def score(data_path, ans_path): # AF burden 
estimation SCORE = [] def is_mat_or_json(file): return (file.endswith(".json")) + (file.endswith(".mat")) ans_set = filter(is_mat_or_json, os.listdir(ans_path)) # test_set = open(os.path.join(data_path, 'RECORDS'), 'r').read().splitlines() for i, ans_sample in enumerate(ans_set): sample_nam = ans_sample.split(".")[0] sample_path = os.path.join(data_path, sample_nam) endpoints_pred = load_ans(os.path.join(ans_path, ans_sample)) TrueRef = RefInfo(sample_path) if len(endpoints_pred) == 0: class_pred = 0 elif len(endpoints_pred) == 1 and np.diff(endpoints_pred)[-1] == TrueRef.len_sig - 1: class_pred = 1 else: class_pred = 2 ur_score = ur_calculate(TrueRef.class_true, class_pred) if TrueRef.class_true == 1 or TrueRef.class_true == 2: ue_score = ue_calculate( endpoints_pred, TrueRef.endpoints_true, TrueRef.onset_score_range, TrueRef.offset_score_range, ) else: ue_score = 0 u = ur_score + ue_score SCORE.append(u) score_avg = np.mean(SCORE) return score_avg ################################################################### # custom metric computing function ################################################################### def compute_challenge_metric( class_true: int, class_pred: int, endpoints_true: Sequence[Sequence[int]], endpoints_pred: Sequence[Sequence[int]], onset_score_range: Sequence[float], offset_score_range: Sequence[float], ) -> float: """ compute challenge metric for a single record Parameters ---------- class_true: int, labelled for the record class_pred: int, predicted class for the record endpoints_true: sequence of intervals, labelled intervals of AF episodes endpoints_pred: sequence of intervals, predicted intervals of AF episodes onset_score_range: sequence of float, scoring mask for the AF onset predictions offset_score_range: sequence of float, scoring mask for the AF offset predictions Returns ------- u: float, the final score for the prediction """ ur_score = ur_calculate(class_true, class_pred) ue_score = ue_calculate(endpoints_pred, endpoints_true, onset_score_range, offset_score_range) u = ur_score + ue_score return u def gen_endpoint_score_mask( siglen: int, critical_points: Sequence[int], af_intervals: Sequence[Sequence[int]], bias: dict = {1: 1, 2: 0.5}, verbose: int = 0, ) -> Tuple[np.ndarray, np.ndarray]: """ generate the scoring mask for the onsets and offsets of af episodes, Parameters ---------- siglen: int, length of the signal critical_points: sequence of int, locations (indices in the signal) of the critical points, including R peaks, rhythm annotations, etc, which are stored in the `sample` fields of an wfdb annotation file (corr. beat ann, rhythm ann are in the `symbol`, `aux_note` fields) af_intervals: sequence of intervals, intervals of the af episodes in terms of indices in `critical_points` bias: dict, default {1:1, 2:0.5}, keys are bias (with ±) in terms of number of rpeaks values are corresponding scores verbose: int, default 0, log verbosity Returns ------- (onset_score_mask, offset_score_mask): 2-tuple of ndarray, scoring mask for the onset and offsets predictions of af episodes NOTE ---- 1. the onsets in `af_intervals` are 0.15s ahead of the corresponding R peaks, while the offsets in `af_intervals` are 0.15s behind the corresponding R peaks. 2. 
for records [data_39_4,data_48_4,data_68_23,data_98_5,data_101_5,data_101_7,data_101_8,data_104_25,data_104_27], the official `RefInfo._gen_endpoint_score_range` slightly expands the scoring intervals at heads or tails of the records, which strictly is incorrect as defined in the `Scoring` section of the official webpage (http://2021.icbeb.org/CPSC2021) """ _critical_points = list(critical_points) if 0 not in _critical_points: _critical_points.insert(0, 0) _af_intervals = [[itv[0] + 1, itv[1] + 1] for itv in af_intervals] if verbose >= 2: print(f"0 added to _critical_points, len(_critical_points): {len(_critical_points)-1} ==> {len(_critical_points)}") else: _af_intervals = [[itv[0], itv[1]] for itv in af_intervals] # records with AFf mostly have `_critical_points` ending with `siglen-1` # but in some rare case ending with `siglen` if siglen - 1 in _critical_points: _critical_points[-1] = siglen if verbose >= 2: print(f"in _critical_points siglen-1 (={siglen-1}) changed to siglen (={siglen})") elif siglen in _critical_points: pass else: _critical_points.append(siglen) if verbose >= 2: print( f"siglen (={siglen}) appended to _critical_points, len(_critical_points): {len(_critical_points)-1} ==> {len(_critical_points)}" ) onset_score_mask, offset_score_mask = np.zeros((siglen,)), np.zeros((siglen,)) for b, v in bias.items(): mask_onset, mask_offset = np.zeros((siglen,)), np.zeros((siglen,)) for itv in _af_intervals: onset_start = _critical_points[max(0, itv[0] - b)] # note that the onsets and offsets in `_af_intervals` already occupy positions in `_critical_points` onset_end = _critical_points[min(itv[0] + 1 + b, len(_critical_points) - 1)] if verbose > 0: print(f"custom --- onset (c_ind, score {v}): {max(0, itv[0]-b)} --- {min(itv[0]+1+b, len(_critical_points)-1)}") print( f"custom --- onset (sample, score {v}): {_critical_points[max(0, itv[0]-b)]} --- {_critical_points[min(itv[0]+1+b, len(_critical_points)-1)]}" ) mask_onset[onset_start:onset_end] = v # note that the onsets and offsets in `af_intervals` already occupy positions in `_critical_points` offset_start = _critical_points[max(0, itv[1] - 1 - b)] offset_end = _critical_points[min(itv[1] + b, len(_critical_points) - 1)] if verbose > 0: print( f"custom --- offset (c_ind, score {v}): {max(0, itv[1]-1-b)} --- {min(itv[1]+b, len(_critical_points)-1)}" ) print( f"custom --- offset (sample, score {v}): {_critical_points[max(0, itv[1]-1-b)]} --- {_critical_points[min(itv[1]+b, len(_critical_points)-1)]}" ) mask_offset[offset_start:offset_end] = v onset_score_mask = np.maximum(onset_score_mask, mask_onset) offset_score_mask = np.maximum(offset_score_mask, mask_offset) return onset_score_mask, offset_score_mask # aliases gen_endpoint_score_range = gen_endpoint_score_mask compute_metrics = compute_challenge_metric
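
# Illustrative sketch (not part of the original module): scoring the prediction for a
# single record.  ``dr`` is a ``CPSC2021`` reader instance, and ``pred_class`` (0: N,
# 1: AFf, 2: AFp) together with ``pred_episodes`` (a list of ``[start, end]`` sample
# indices) are hypothetical model outputs, used here only for illustration:
#
#     >>> rec = "data_98_5"
#     >>> siglen = dr.df_stats[dr.df_stats.record == rec].iloc[0].sig_len
#     >>> onset_mask, offset_mask = gen_endpoint_score_mask(
#     ...     siglen=siglen,
#     ...     critical_points=wfdb.rdann(str(dr.get_absolute_path(rec)), extension=dr.ann_ext).sample,
#     ...     af_intervals=dr.load_af_episodes(rec, fmt="c_intervals"),
#     ... )
#     >>> u = compute_metrics(
#     ...     class_true=dr.load_label(rec, fmt="n"),
#     ...     class_pred=pred_class,
#     ...     endpoints_true=dr.load_af_episodes(rec),
#     ...     endpoints_pred=pred_episodes,
#     ...     onset_score_range=onset_mask,
#     ...     offset_score_range=offset_mask,
#     ... )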