A Novel Transformer-Based CVD Diagnostic Tool

Introduction

Cardiovascular diseases (CVDs) are a group of conditions affecting the heart and blood vessels. CVDs accounted for 32% of all global deaths in 2019. Common examples include heart attacks (myocardial infarctions), hypertrophic cardiomyopathy, and coronary artery disease. CVDs are the leading cause of death in the U.S., where one person dies from them every 34 seconds.

Electrocardiograms (ECGs) are a common non-invasive way to diagnose CVDs. A standard 12-lead ECG records the heart's electrical activity from 12 different views (leads), derived from 10 electrodes placed on the limbs and chest. The timing, shape, and duration of the resulting waveforms give a physician insight into a patient's condition. In particular, a physician might examine the P wave, QRS complex, and T wave, along with the intervals between them, to uncover abnormalities.

Problem

A hospital or medical facility can diagnose CVDs from ECGs in several ways; however, each has disadvantages and limitations.

ECGs are commonly interpreted manually by clinicians based on the morphological characteristics of the waveform. Manual diagnosis of CVDs is often difficult because the signals are mixed with noise, which introduces considerable subjectivity. Clinicians also need extensive training to interpret ECG output, which makes diagnosis difficult where resources or trained physicians are unavailable. To address these issues, several machine learning methods have been used to classify certain types of CVDs.

LSTMs (long short-term memory networks) are a type of recurrent neural network (RNN) well suited to sequences and time-series data, and they have previously been used to interpret ECG signals. However, LSTM-based interpretations fall short because the model struggles to retain truly long-range dependencies, which are required for analyzing continuous ECG data such as recordings collected in hospital settings or over the course of a disease.

A two-dimensional convolutional neural network (CNN) is another machine learning method that has been used to classify ECGs. The ECG plots are converted to 2D grayscale images, which are cropped, augmented, and passed through CNN-based models to classify the disease. One drawback of this method is that the converted images contain extra white space, which slows processing. The model also cannot capture dependencies in the ECG sequences because it relies on simple feature extraction. In addition, 2D CNNs are sensitive to image transformations, which limits the transferability of the model between monitors with different resolutions.

To address the lack of long-range dependencies between data points in the ECG, encoder-only vanilla transformers have been applied to the task. Such a transformer is a stack of encoder blocks, each composed of a self-attention layer and a feed-forward neural network (FFN). Self-attention computes a representation of a sequence by relating every position in the sequence to every other position in the same sequence. This lets the model pick up patterns that span distant parts of the signal, which can improve classification. However, transformers have a severe disadvantage: the compute and memory cost of self-attention grows quadratically with input sequence length. As a result, a transformer deployed in a clinical setting cannot infer truly long-range dependencies in the data and is instead constrained to a short window of time. The issues discussed above (poor adaptability to varying data, high resource demands, and the inability to capture truly long-range dependencies) are major flaws that need attention.
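To make the quadratic cost concrete, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. The weight matrices, the hidden size `d = 16`, and the sequence lengths are illustrative assumptions, not values from any ECG model; the point is that the attention-weight matrix holds one entry for every pair of positions, so a sequence of length n needs n × n entries.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a sequence x of shape (n, d)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    # Every position is scored against every other position: shape (n, n).
    scores = q @ k.T / np.sqrt(q.shape[-1])
    # Numerically stable softmax over each row.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
d = 16  # illustrative hidden size
w_q, w_k, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
for n in (500, 1000, 2000):
    x = rng.standard_normal((n, d))
    out, weights = self_attention(x, w_q, w_k, w_v)
    print(n, weights.size)  # number of attention entries grows as n**2
```

Doubling the sequence from 1,000 to 2,000 steps quadruples the attention matrix from 1,000,000 to 4,000,000 entries, which is the scaling that limits vanilla transformers on long ECG recordings.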

Solution Criteria

To overcome the limitations of classical transformers, I set the following criteria:

  • The model must capture long-range dependencies in ECG waveform sequences

  • The model must be able to process data of variable sequence lengths

  • The model must use a lightweight architecture so that it can be deployed in clinical settings on devices with limited storage and memory capacity

  • The model must be repurposed and tuned to classify ECG waveforms accurately

The model would be evaluated against state-of-the-art CVD classifiers to determine whether my model is able to attain the same level of accuracy.

Solution

After reviewing the current machine learning literature, one model stood out: Mega (Moving Average Equipped Gated Attention), a recent transformer-based model that addresses several flaws of vanilla transformers. Mega has mainly been applied to natural language processing (NLP) tasks, though it has also been evaluated on images and raw audio. In short, Mega uses gated, single-head attention to avoid the computational cost of multi-head attention while maintaining comparable expressivity.

Mega's architecture can be summarized by the following diagram:

[Figure: Mega architecture diagram]

Outputs from a bidirectional damped EMA (exponential moving average) are passed into a chunk-wise self-attention layer. Both components are discussed in turn.

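Mega's first component is a damped exponential moving average (EMA), which computes each output from the current input and the previous output:

$$ y_t = \alpha \odot x_t + (1 - \alpha \odot \delta) \odot y_{t-1} $$

where $\alpha \in (0, 1)^d$ is a decay factor, $\delta \in (0, 1)^d$ is a damping factor, and $\odot$ denotes element-wise multiplication. Unrolling the recurrence turns it into a weighted summation over past inputs, $y_t = \sum_{k=0}^{t} \alpha \odot (1 - \alpha \odot \delta)^k \odot x_{t-k}$ (taking $y_{-1} = 0$), so the farther an input lies from the current point, the more strongly the decay shrinks its contribution.

Mega increases expressivity by introducing a multidimensional form of EMA, in which each dimension of the input X has a separate EMA performed on it. First, each dimension $j$ of X is expanded into an $h$-dimensional vector:

$$ u_t^{(j)} = \beta_j \, x_{t,j} $$

where $\beta \in \mathbb{R}^{d \times h}$ is a projection (expansion) matrix. The damped EMA is then applied to each expanded dimension, now with $\alpha, \delta \in (0, 1)^{d \times h}$, and the result is projected back to a single value per dimension using a second matrix $\eta \in \mathbb{R}^{d \times h}$:

$$ h_t^{(j)} = \alpha_j \odot u_t^{(j)} + (1 - \alpha_j \odot \delta_j) \odot h_{t-1}^{(j)}, \qquad y_{t,j} = \eta_j^{\top} h_t^{(j)} $$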
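The multidimensional damped EMA can be sketched in NumPy as a simple loop over time steps. This is a minimal, unidirectional sketch with randomly initialized parameters and illustrative sizes (1,000 steps, 12 channels, h = 4); Mega's actual implementation computes the EMA far more efficiently (as a convolution) and uses a bidirectional variant, neither of which is shown here.

```python
import numpy as np

def multidim_damped_ema(x, alpha, delta, beta, eta):
    """Multidimensional damped EMA.

    x: input sequence of shape (n, d)
    alpha, delta: decay and damping factors of shape (d, h), values in (0, 1)
    beta, eta: expansion and projection matrices of shape (d, h)
    """
    n, d = x.shape
    state = np.zeros_like(beta, dtype=float)  # h_{t-1}^{(j)} for every dimension j: shape (d, h)
    y = np.empty((n, d))
    for t in range(n):
        u = beta * x[t][:, None]                          # expand: u_t^{(j)} = beta_j * x_{t,j}
        state = alpha * u + (1 - alpha * delta) * state   # damped EMA on each expanded dimension
        y[t] = np.sum(eta * state, axis=1)                # project back: eta_j^T h_t^{(j)}
    return y

rng = np.random.default_rng(0)
n, d, h = 1000, 12, 4  # illustrative sizes
sigmoid = lambda z: 1 / (1 + np.exp(-z))  # keeps alpha and delta inside (0, 1)
alpha, delta = sigmoid(rng.standard_normal((d, h))), sigmoid(rng.standard_normal((d, h)))
beta, eta = rng.standard_normal((d, h)), rng.standard_normal((d, h))
y = multidim_damped_ema(rng.standard_normal((n, d)), alpha, delta, beta, eta)
print(y.shape)  # (1000, 12)
```

With d = h = 1 and β = η = 1 the function reduces to the scalar damped EMA, so an impulse input [1, 0, 0] with α = 0.5 and δ = 1 decays as 0.5, 0.25, 0.125, matching the unrolled summation.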


Another improvement is the replacement of the vanilla transformer's multi-head attention with a single-head gated attention block; the Mega authors showed that this single-head block is as expressive as multi-head attention. Most importantly, a chunk-wise self-attention layer reduces the time and memory complexity of the model. Vanilla transformers have an O(n²) dependence on the sequence length n. Mega improves on this by splitting the sequence into k = n/c chunks of length c and applying attention only within each chunk. Each chunk costs O(c²), so the total cost is O(k·c²) = O(n·c), which is linear in n for a fixed chunk size. This time and memory complexity is far better than that of vanilla transformers and allows the model to handle much longer sequences. Attention no longer spans chunk boundaries, but the EMA component carries information from earlier positions across them, so long-range context is not lost.

A final reason for choosing the Mega architecture is its performance on the Long Range Arena (LRA) benchmark, a set of long-sequence tasks designed to assess how well efficient transformer-style models capture long-range dependencies. Mega was the highest-performing model in all categories.
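The chunking arithmetic can be sketched in NumPy as follows. This minimal version assumes plain softmax attention inside each chunk (it omits Mega's gating and other details) and a sequence length that divides evenly into chunks; `score_entries` is a hypothetical helper that simply counts how many attention scores are computed.

```python
import numpy as np

def softmax_attention(q, k, v):
    """Full softmax attention: every query attends to every key."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def chunked_attention(q, k, v, chunk):
    """Attend only within non-overlapping chunks of length `chunk`."""
    n = q.shape[0]
    if n % chunk:
        raise ValueError("sequence length must be a multiple of the chunk size")
    out = np.empty_like(v)
    for s in range(0, n, chunk):
        out[s:s + chunk] = softmax_attention(q[s:s + chunk], k[s:s + chunk], v[s:s + chunk])
    return out

def score_entries(n, chunk=None):
    """Attention scores computed: n*n for full attention, (n/c)*c*c = n*c when chunked."""
    return n * n if chunk is None else (n // chunk) * chunk * chunk

print(score_entries(1024), score_entries(1024, chunk=128))  # 1048576 131072
```

For a fixed chunk size, doubling n only doubles the chunked count, whereas the full-attention count quadruples. Setting the chunk size equal to n recovers ordinary full attention.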

Connections

Given all this information, one major question arises: how can a language model be used for CVD classification? The answer is that ECG data points can be converted into embeddings, which the model then processes as sequences. For CVD classification, an audio-based Mega architecture would be repurposed so that minimal additional tuning is needed. This works because ECG and audio recordings are both time-series signals digitized at a preset sampling rate, so they can be fed to the model in the same way.
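One simple way to turn ECG data points into embeddings is a per-time-step linear projection, sketched below. This is a hypothetical illustration, not ECGenius's actual preprocessing: the per-lead z-score normalization, the embedding size `d_model = 64`, and the random projection weights (standing in for learned ones) are all assumptions. The 1,000 × 12 shape corresponds to 10 seconds of 12-lead ECG at 100 Hz.

```python
import numpy as np

def embed_ecg(ecg, w, b):
    """Turn a (timesteps, leads) ECG array into a (timesteps, d_model) sequence of embeddings."""
    # Normalize each lead to zero mean and unit variance so all leads share a common scale.
    mean = ecg.mean(axis=0, keepdims=True)
    std = ecg.std(axis=0, keepdims=True) + 1e-8
    normalized = (ecg - mean) / std
    # Linear projection: the 12 lead values at each time step become one d_model-dim token.
    return normalized @ w + b

rng = np.random.default_rng(0)
sampling_rate_hz, seconds, leads, d_model = 100, 10, 12, 64
ecg = rng.standard_normal((sampling_rate_hz * seconds, leads))  # stand-in for a real recording
w = rng.standard_normal((leads, d_model)) * 0.1  # hypothetical learned projection
b = np.zeros(d_model)
tokens = embed_ecg(ecg, w, b)
print(tokens.shape)  # (1000, 64)
```

The resulting sequence of 1,000 tokens can then be processed like any other token sequence, which is what lets a sequence model built for text or audio consume ECG data.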

Data and Benchmark

The model was trained on PTB-XL, an ECG waveform dataset. PTB-XL is one of the largest publicly available 12-lead ECG waveform datasets, containing over 21,000 ten-second records, each annotated by up to two cardiologists. Several models have been evaluated on PTB-XL, with macro-averaged AUROC scores reported across all the diagnostic classes. Most scores fall in the upper 0.8s to low 0.9s. These scores serve as the benchmark for evaluating the proposed model's accuracy.
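PTB-XL labels each record with SCP-ECG statement codes (the `scp_codes` column of `ptbxl_database.csv`), and its `scp_statements.csv` file groups the diagnostic statements into five superclasses: NORM, MI, STTC, CD, and HYP. The sketch below converts a record's codes into a five-element multi-hot label vector. The hand-written mapping is a small illustrative subset standing in for reading `diagnostic_class` from `scp_statements.csv`; it is not ECGenius's actual loader.

```python
import numpy as np

SUPERCLASSES = ["NORM", "MI", "STTC", "CD", "HYP"]

# Illustrative subset of the statement -> superclass mapping; the full mapping
# would be read from the diagnostic_class column of PTB-XL's scp_statements.csv.
SCP_TO_SUPERCLASS = {
    "NORM": "NORM",   # normal ECG
    "IMI": "MI",      # inferior myocardial infarction
    "ASMI": "MI",     # anteroseptal myocardial infarction
    "NDT": "STTC",    # non-diagnostic T abnormalities
    "CLBBB": "CD",    # complete left bundle branch block
    "LVH": "HYP",     # left ventricular hypertrophy
}

def to_multi_hot(scp_codes):
    """Map a record's {statement: likelihood} dict to a 5-element 0/1 label vector."""
    labels = np.zeros(len(SUPERCLASSES), dtype=np.int8)
    for code in scp_codes:
        superclass = SCP_TO_SUPERCLASS.get(code)
        if superclass is not None:  # codes outside the mapping (e.g. rhythm code "SR") are skipped
            labels[SUPERCLASSES.index(superclass)] = 1
    return labels

print(to_multi_hot({"IMI": 100.0, "LVH": 50.0, "SR": 0.0}))  # [0 1 0 0 1]
```

Because a record can carry several diagnoses at once, the labels form a multilabel problem rather than a single-class one, which is why each superclass is later scored with its own one-vs-rest metric.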

Preprocessing

Before the Mega model could be trained on ECG data, some initial setup was needed so that the benchmark and the proposed model could be compared fairly. Mega's existing code used a multi-head multiclass classifier that reported accuracy and F1 scores; this had to be modified into a single-head multiclass classifier. In addition, the code had to compute an AUROC score for each category and macro-average them at the end of each epoch. With these modifications, the model's output became compatible with the benchmark. Finally, a data loader had to be written to feed the training data into the model.
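The per-category and macro-averaged AUROC computation can be sketched from scratch in NumPy, as below. This is an illustrative version, not the code added to the Mega repository: it uses the fact that AUROC equals the probability that a randomly chosen positive example scores higher than a randomly chosen negative one (ties count as one half). Libraries such as scikit-learn offer `roc_auc_score` with `average="macro"` for the same purpose.

```python
import numpy as np

def auroc(y_true, y_score):
    """AUROC as the probability a random positive outranks a random negative (ties count 1/2)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_score = np.asarray(y_score, dtype=float)
    pos, neg = y_score[y_true], y_score[~y_true]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUROC needs at least one positive and one negative example")
    diff = pos[:, None] - neg[None, :]  # compare every positive/negative pair
    return float(np.mean(diff > 0) + 0.5 * np.mean(diff == 0))

def macro_auroc(y_true, y_score):
    """Average the per-class (one-vs-rest) AUROCs, giving each class equal weight."""
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)
    per_class = [auroc(y_true[:, c], y_score[:, c]) for c in range(y_true.shape[1])]
    return float(np.mean(per_class)), per_class

# Toy example: 4 records, 2 classes (multi-hot labels and predicted probabilities).
y_true = np.array([[0, 1], [0, 0], [1, 0], [1, 1]])
y_score = np.array([[0.1, 0.9], [0.4, 0.2], [0.35, 0.3], [0.8, 0.6]])
mean_auc, per_class = macro_auroc(y_true, y_score)
print(mean_auc, per_class)  # 0.875 [0.75, 1.0]
```

Macro-averaging weights every diagnostic class equally, so a rare class such as hypertrophy counts as much as the common normal class.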

Implementation (Training and validation)

I developed the model on Paperspace, a cloud computing platform, using a machine with 45 GB of RAM, 8 CPUs, and a 16 GB NVIDIA A4000 GPU. Outputs such as per-epoch graphs were generated with Weights & Biases. I chose a training batch size of 20, which sped up training with minimal effect on accuracy. The PTB-XL records were divided into an 80/10/10 split (train, validation, and test). During training, the Adam (adaptive moment estimation) optimizer was used to update the parameters according to:

$$ \theta_n = \theta_{n-1} - \alpha \, \frac{\hat{m}_n}{\sqrt{\hat{v}_n} + \epsilon} $$

where $\theta_n$ is the vector of parameters after n training steps, $\alpha$ is the step size, $\epsilon$ is a small constant that prevents division by zero, $\hat{m}_n$ is the bias-corrected first moment estimate, and $\hat{v}_n$ is the bias-corrected second raw moment estimate. Adam is useful because of its momentum term (the first moment estimate $\hat{m}_n$, an exponentially decaying average of past gradients), which smooths the updates and can help the model move past shallow local minima in the loss function. During training, several hyperparameters were tuned, including the learning rate, β₁ and β₂ (the decay rates of the two moment estimates), ε, the encoder hidden dimension (which controls the model's expressivity), and the number of hidden layers. After training, the model produced five binary classifications, one per diagnostic category, along with macro-averaged accuracy, F1, and AUROC scores.
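The update rule can be sketched in a few lines of NumPy. This minimal version uses the Adam paper's default decay rates (β₁ = 0.9, β₂ = 0.999, ε = 10⁻⁸), not the tuned values used for ECGenius, and a larger toy step size α = 0.1 so that it converges quickly on a simple quadratic standing in for the model's loss.

```python
import numpy as np

def adam(grad_fn, theta0, alpha=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=1000):
    """Minimize a function with Adam, given a function that returns its gradient."""
    theta = np.asarray(theta0, dtype=float)
    m = np.zeros_like(theta)  # first moment (momentum): decaying average of gradients
    v = np.zeros_like(theta)  # second raw moment: decaying average of squared gradients
    for n in range(1, steps + 1):
        g = grad_fn(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**n)  # bias correction, since both moments start at zero
        v_hat = v / (1 - beta2**n)
        theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta

# Toy example: minimize f(theta) = (theta - 3)^2, whose gradient is 2 * (theta - 3).
print(adam(lambda th: 2 * (th - 3), theta0=[0.0]))  # converges toward theta = 3
```

Thanks to the bias correction, the very first step has size almost exactly α (here moving θ from 0 to about 0.1), regardless of how large the raw gradient is.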

This is a visualization of 10 sample runs of the model, showing the change in validation AUROC across epochs. Each run lasted around 2 hours. The model was run 71 times in total, and each run reached its peak score around halfway through.

Results and Evaluation

Receiver operating characteristic (ROC) curves are a threshold-agnostic visualization of the tradeoff between true positives and false positives: each curve plots the true positive rate against the false positive rate as the decision threshold varies. To build ROC curves for my model, I used a one-vs-rest method to handle the multiple labels. This turns a multilabel classification problem into binary ones by considering only two classes at a time: the class of interest and all others. Below is a graph of these ROC curves; five were generated, one for each condition of interest.
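A one-vs-rest ROC curve for a single class can be traced by sweeping the threshold from the highest score to the lowest, as in this minimal NumPy sketch (an illustration, not the plotting code used for the figures). Tied scores share one threshold, and the trapezoidal area under the resulting points is the AUROC.

```python
import numpy as np

def roc_curve_points(y_true, y_score):
    """False/true positive rates at every distinct threshold, from strictest to loosest."""
    y_true = np.asarray(y_true, dtype=bool)
    y_score = np.asarray(y_score, dtype=float)
    order = np.argsort(-y_score, kind="stable")  # highest scores first
    scores, labels = y_score[order], y_true[order]
    tp = np.cumsum(labels)   # positives flagged when thresholding at each score
    fp = np.cumsum(~labels)  # negatives flagged when thresholding at each score
    # Keep only the last index of each run of tied scores, so ties share a threshold.
    last = np.r_[np.nonzero(np.diff(scores))[0], len(scores) - 1]
    tpr = np.r_[0.0, tp[last] / labels.sum()]
    fpr = np.r_[0.0, fp[last] / (~labels).sum()]
    return fpr, tpr

def area_under(fpr, tpr):
    """Trapezoidal area under the ROC curve."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

# One class of interest (1) versus all others (0).
fpr, tpr = roc_curve_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(area_under(fpr, tpr))  # 0.75
```

Repeating this once per diagnostic class, with that class as the positive label, yields the five one-vs-rest curves.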

As one can see, the highest-performing category was “Normal” and the lowest was “Hypertrophy”. All labels had an AUROC ≥ 0.9.

Below is a set of confusion matrices depicting the model's performance on each category. The matrices correspond to “Normal”, “Myocardial Infarction”, “ST/T Change”, “Conduction Disturbance”, and “Hypertrophy”, respectively, reading left to right and top to bottom. Each matrix summarizes the model's predictions for one category: the top-left and bottom-right squares are correct predictions (true negatives and true positives), and the other two squares are incorrect. The lighter a correct-prediction square is, the more accurate the model is for that category.
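Each one-vs-rest matrix can be built by thresholding the predicted probability for one class, as in this minimal NumPy sketch. The 0.5 threshold is an assumption, and the class codes follow PTB-XL's superclass abbreviations; the layout `[[TN, FP], [FN, TP]]` puts the correct predictions in the top-left and bottom-right squares, as described above.

```python
import numpy as np

def binary_confusion(y_true, y_prob, threshold=0.5):
    """2x2 one-vs-rest confusion matrix laid out as [[TN, FP], [FN, TP]]."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_prob) >= threshold
    tn = int(np.sum(~y_true & ~y_pred))
    fp = int(np.sum(~y_true & y_pred))
    fn = int(np.sum(y_true & ~y_pred))
    tp = int(np.sum(y_true & y_pred))
    return np.array([[tn, fp], [fn, tp]])

CLASSES = ["NORM", "MI", "STTC", "CD", "HYP"]

def per_class_confusion(y_true, y_prob, threshold=0.5):
    """One confusion matrix per diagnostic class from (records, classes) arrays."""
    return {name: binary_confusion(y_true[:, c], y_prob[:, c], threshold)
            for c, name in enumerate(CLASSES)}

print(binary_confusion([1, 1, 0, 0, 0], [0.8, 0.4, 0.7, 0.2, 0.1]))
```

In this toy example one positive is caught (TP), one is missed (FN), one negative is falsely flagged (FP), and two negatives are correctly rejected (TN).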

Below is another interesting visualization: an attribution map. Attribution maps add to the interpretability and research potential of the model. Simply put, an attribution map depicts the model's “thought process” when making a diagnosis (this particular one is a heart attack). The model relied more heavily on darker regions of the map, and less on lighter regions, when making its classification.
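The write-up does not specify which attribution method produced the map. As one illustration of how such a map can be computed, here is a minimal occlusion-sensitivity sketch (a swapped-in technique, not necessarily the one used for ECGenius): each window of time steps is zeroed out in turn, and the resulting drop in the model's score becomes that window's attribution. `toy_predict` is a hypothetical stand-in for a trained model's probability for the class of interest.

```python
import numpy as np

def occlusion_attribution(predict, signal, window):
    """Score drop when each window of time steps is zeroed out; larger means more important."""
    baseline = predict(signal)
    attribution = np.zeros(signal.shape[0])
    for start in range(0, signal.shape[0], window):
        occluded = signal.copy()
        occluded[start:start + window] = 0.0  # zero every lead in this time window
        attribution[start:start + window] = baseline - predict(occluded)
    return attribution

# Hypothetical toy model: it only looks at time steps 4 and 5.
weights = np.array([0, 0, 0, 0, 1, 1, 0, 0], dtype=float)
toy_predict = lambda s: float(weights @ s.sum(axis=1))
ecg_toy = np.ones((8, 12))  # 8 time steps x 12 leads
print(occlusion_attribution(toy_predict, ecg_toy, window=2))  # nonzero only at steps 4 and 5
```

Plotting such attributions over the ECG trace gives a map in which the regions that most influenced the prediction stand out.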

Here is a comparison of my model (ECGenius) with other state-of-the-art machine learning models evaluated on PTB-XL. As you can see, the overall accuracy achieved on the dataset was 0.9214, ranking 1st among the compared state-of-the-art models. With more rigorous training, I plan to improve the accuracy of ECGenius even further!

Website for Real Time Results and Diagnosis

To make my technology more accessible, I developed a website (screenshot below) that healthcare workers could use for remote or continuous monitoring, or when no physician is available to make a diagnosis.

As you can see, the website provides a diagnosis for each patient, along with a visualization of the patient's ECG, which can be streamed via a server connection. Files containing ECG data can also be uploaded to the model for diagnosis. The website was developed using JavaScript. To connect it to my ML code, I wrote a Flask server that runs on my desktop and communicates with the website's backend. The website could also be expanded into a cloud-based web application.

Conclusion and Social Benefit

ECGenius enables many advancements, both social and technological. For instance, it allows for earlier diagnosis and management of CVDs by emergency responders, even before a physician's assessment. In addition, its ability to retain information over long sequences allows the discovery of longer-range patterns in ECGs that could indicate CVDs. The reduced memory complexity opens the possibility of continuous or remote patient monitoring, and the improved speed may allow real-time diagnoses to be made in a sliding-window fashion. The lightweight architecture allows for deployment in hospitals and other clinical settings, and the website increases accessibility, especially in places without access to clinicians.
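Sliding-window diagnosis could look like the minimal sketch below: a continuous stream is cut into overlapping fixed-length windows, and each window is classified as soon as it is complete. The 100 Hz sampling rate, the 10-second window (matching PTB-XL's record length), the 1-second stride, and the `classify` function are all hypothetical placeholders rather than ECGenius settings.

```python
import numpy as np

def sliding_windows(stream, window, stride):
    """Yield (start_index, window) pairs of length `window`, advancing by `stride` samples."""
    for start in range(0, stream.shape[0] - window + 1, stride):
        yield start, stream[start:start + window]

def monitor(stream, classify, sampling_rate_hz=100, window_s=10, stride_s=1):
    """Classify each window; returns (start time in seconds, prediction) pairs."""
    window, stride = window_s * sampling_rate_hz, stride_s * sampling_rate_hz
    return [(start / sampling_rate_hz, classify(chunk))
            for start, chunk in sliding_windows(stream, window, stride)]

# Hypothetical classifier that just reports the mean amplitude of each window.
stream = np.zeros((1500, 12))  # 15 s of 12-lead data at 100 Hz
results = monitor(stream, classify=lambda w: float(w.mean()))
print(len(results))  # 6 windows, starting at 0 s, 1 s, ..., 5 s
```

A smaller stride gives more frequent updates at the cost of more model calls per second of signal.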

In terms of technology and research, applying a language model to healthcare opens new avenues for future research and exploration. The attribution maps can also help communicate complex scientific results to a wider audience through visual representations. Additionally, the model could be repurposed for other conditions that I was not able to include in my research, expanding the range of conditions that can be identified. Finally, I made an open-source contribution: I modified the Mega repository to support generic multilabel tasks.

Biases, Limitations and Future Directions

While ECGenius offers many benefits and improvements, it also has a few limitations, particularly when applied to real patient data. If the model is given a novel input unlike anything it has seen before, it may produce an inaccurate classification. This could be mitigated by adding a reject option, which lets the model abstain from a diagnosis when it is not confident. Lastly, steps can be taken to allow the model to run on a CPU, so it can be deployed more easily in hospitals and other clinical facilities. These include model quantization (storing weights at lower numerical precision to reduce memory use), knowledge distillation (training a smaller “student model” to mimic the full model's outputs), and model pruning (removing weights and parameters that contribute little to the final output).
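To illustrate how quantization saves space, here is a minimal sketch of symmetric per-tensor 8-bit quantization in NumPy. It is one simple scheme among several, and the weight matrix is a random stand-in rather than an actual ECGenius weight: each float32 weight is replaced by an int8 value plus one shared scale factor.

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor quantization: w is approximated by scale * q, with q in [-127, 127]."""
    max_abs = float(np.max(np.abs(w)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover approximate float32 weights from the int8 values."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 256)).astype(np.float32)  # stand-in weight matrix
q, scale = quantize_int8(w)
error = np.max(np.abs(w - dequantize(q, scale)))
print(w.nbytes, q.nbytes)  # 262144 65536: int8 storage is 4x smaller than float32
print(error)               # at most about scale / 2
```

The rounding error per weight is bounded by half the scale factor, which is the accuracy traded for a fourfold reduction in weight storage.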
