AMP-Oriented Multi-task Model (AOMM)

Exploring antimicrobial peptides for multitask applications through computational modeling.


The AMPOS dataset poses a bioinformatics challenge: evaluate antimicrobial peptides comprehensively by achieving high performance on all subtasks of AMPPT (sequence mask prediction, AMP classification, half-life regression, microorganism-specific minimum inhibitory concentration regression, and hemolytic activity score regression) within a single neural network framework. To address this challenge, we propose the AMP-Oriented Multi-task Model (AOMM), which achieves state-of-the-art performance on all AMPPT subtasks.

  • The number of parameters is approximately 124M.
  • The model is open source on HuggingFace.
Model Usage
from transformers import AutoModel
model = AutoModel.from_pretrained(
    "muskwff/amp4multitask_124M",
    task_name="mic_regression",     # choose the task ["amp_classification", "hemolysis_regression", "mic_regression", "half_life_regression"]
    trust_remote_code=True      # required for loading the model
)
      

The code is available on GitLab.
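
As a quick sanity check after loading, the model can be paired with the tokenizer described later on this page. The snippet below is a minimal sketch that assumes a standard Hugging Face-style forward signature (input_ids, attention_mask); the exact inputs and output fields are defined by the custom modeling code (modeling_amp.py) and may differ.

Inference Sketch
from transformers import AutoModel, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("muskwff/amp4multitask_124M")
model = AutoModel.from_pretrained(
    "muskwff/amp4multitask_124M",
    task_name="amp_classification",
    trust_remote_code=True,
)
model.eval()

inputs = tokenizer("KGGK", return_tensors="pt")   # example peptide from AMPOS
with torch.no_grad():
    outputs = model(**inputs)                     # assumed interface; check modeling_amp.py
print(outputs)                                    # inspect to locate the task prediction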

Model Introduction


AOMM is a Transformer [1] model that focuses on six tasks related to antimicrobial peptides (AMPs):

  1. sequence mask prediction: masked-sequence training, a common task in NLP (auxiliary task)
  2. AMP classification: determine whether a peptide is an antimicrobial peptide
  3. microorganism-specific minimum inhibitory concentration (MIC) regression
  4. half-life regression
  5. hemolytic activity score regression
  6. bioactivity classification (auxiliary task)

[1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.


Figure 1: The model architecture of AOMM

Model Architecture

The model architecture consists of two parts, an encoder and a decoder, i.e. the model is an encoder-decoder model. The main parameters of the model are listed below:

model_max_length | num_hidden_layers | hidden_size | num_attention_heads
128              | 18                | 768         | 24

We use the whole sequence feature as the decoder input for the mask prediction task, while for the other tasks the decoder input is the CLS-token feature of the sequence.
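
In code this amounts to slicing the encoder output. A minimal sketch, assuming the encoder output has shape [batch_size, 128, 768] and the CLS token sits at position 0 (variable names are illustrative):

Decoder Input Selection (sketch)
import torch

encoder_output = torch.randn(4, 128, 768)    # [batch_size, seq_len, hidden_size]

mask_head_input = encoder_output             # mask prediction: per-token features
cls_head_input = encoder_output[:, 0, :]     # other tasks: CLS feature, [batch_size, hidden_size]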

Encoder (AMP Layer)

The encoder is a stack of Transformer encoder layers built on self-attention. It consists of 18 layers; each layer contains a multi-head self-attention block and a feed-forward network. The feed-forward network is composed of two linear layers, and each sub-block is followed by layer normalization. The encoder input is the tokenized sequence, and each layer uses two residual (skip) connections. The output features have shape (batch_size, seq_len, hidden_size), i.e. [B, 128, 768].


Figure 2: The structure of an AMP layer

Encoder (AMP Layer) Code
class AMPLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = AMPAttention(config)
        self.attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.intermediate = nn.Linear(config.hidden_size, config.intermediate_size)
        self.output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.activation = nn.GELU()
        self.config = config
        
    def forward(self, hidden_states, attention_mask=None):
        residual = hidden_states
        attention_output = self.attention(hidden_states, attention_mask)
        attention_output = self.dropout(attention_output)
        attention_output = residual + attention_output
        attention_output = self.attn_layer_norm(attention_output)
        
        residual = attention_output
        intermediate_output = self.intermediate(attention_output)
        intermediate_output = self.activation(intermediate_output)
        layer_output = self.output(intermediate_output)
        layer_output = self.dropout(layer_output)
        layer_output = residual + layer_output
        layer_output = self.final_layer_norm(layer_output)
        return layer_output
      

AMP Attention

AMP Attention is a multi-head self-attention mechanism. Within the attention computation, we use rotary position embedding (RoPE) to encode the positional relationships between input tokens.


Figure 3: The structure of AMP Attention

The self-attention procedure with RoPE is as follows: for each token embedding in the sequence, first compute the corresponding query and key vectors; then compute the rotary position encoding for each token position; next apply the rotation transformation pairwise to the elements of the query and key vectors at each position; finally compute the inner product between the rotated queries and keys to obtain the self-attention scores.


Figure 4: Rotation transformation demonstration of the RoPE algorithm

The code for RoPE in AMP attention is as follows:

AMP Attention Code
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        freqs_cis = precompute_freqs_cis(dim, max_seq_len)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
    
    def forward(self, seq_len=None):
        seq_len = seq_len if seq_len is not None else self.max_seq_len
        return self.freqs_cis[:seq_len]
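
The helper precompute_freqs_cis and the function that applies the rotation to the query/key vectors are not shown above. The sketch below illustrates what they typically look like in common RoPE implementations; the exact AOMM implementation may differ.

RoPE Helpers (sketch)
import torch

def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
    """Precompute the complex rotation factors e^(i*m*theta_k) used by RoPE."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, freqs)                 # [max_seq_len, dim/2]
    return torch.polar(torch.ones_like(angles), angles)    # complex64

def apply_rotary_emb(x, freqs_cis):
    """Rotate query/key vectors pairwise; x has shape [batch, seq_len, heads, head_dim]."""
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rot = freqs_cis[: x.shape[1]].unsqueeze(0).unsqueeze(2)  # broadcast over batch and heads
    x_rotated = torch.view_as_real(x_complex * rot).flatten(-2)
    return x_rotated.type_as(x)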
      

Decoder (Task-Aware Head)

The decoder is a coupled multilayer perceptron (MLP) block containing one sub-block per task (six in total). Except for the MIC regression task, each task's decoder is a single MLP head.


Figure 5: The structure and parameter configuration of the Decoder

For the MIC regression task, we combine the sequence feature with an organism feature. We therefore design an organism_encoder to obtain the organism feature and an interactor to fuse the sequence and organism features.

ActivityRegressionHead Code
class ActivityRegressionHead(nn.Module):
    def __init__(self, embedding_dim, num_organisms):
        super().__init__()        
        self.organism_encoder = nn.Sequential(
            nn.Embedding(num_organisms, embedding_dim),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, embedding_dim),
            nn.GELU(),
            nn.Dropout(0.1),
        )

        self.interactor = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.GELU(),
            nn.Dropout(0.2)
        )

        self.regressor = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(embedding_dim // 2, embedding_dim // 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim // 4, 1)
        )

    def forward(self, features, organism_ids):
        org_features = self.organism_encoder(organism_ids)
        combined = torch.cat([features, org_features], dim=1)
        interaction = self.interactor(combined)
        interaction += features  # residual fusion with sequence features
        return self.regressor(interaction)
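
As a quick shape check, the head can be exercised with random features. This assumes the ActivityRegressionHead class above is importable (with torch and torch.nn in scope); the organism indices follow the MIC table later on this page.

ActivityRegressionHead Usage (sketch)
import torch

head = ActivityRegressionHead(embedding_dim=768, num_organisms=11)
features = torch.randn(4, 768)               # CLS features for a batch of 4 peptides
organism_ids = torch.tensor([0, 4, 9, 7])    # e.g. A. baumannii = 0, E. coli = 4, S. aureus = 9
print(head(features, organism_ids).shape)    # torch.Size([4, 1])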

All task decoders are stored in the AMPForMultiTask module as a ModuleDict:

Register Task Heads in AMPForMultiTask
self.task_heads = nn.ModuleDict()
for task_name, task_cfg in task_config.items():   # avoid shadowing task_config
    num_orgs = task_cfg.get("num_organisms")
    self.task_heads[task_name] = TaskAwareHead(
        config,
        str(task_name),
        num_organisms=num_orgs
    )

Tokenizer Configuration

The tokenizer used in AOMM is the same as the ESM-2 tokenizer, which you can find on HuggingFace. Because the sequence length of antimicrobial peptides is generally between 1 and 100 amino acids, we set model_max_length to 128.

Tokenizer Usage
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("muskwff/amp4multitask_124M")
Tokenizer Config (JSON)
{
  "model_max_length": 128,
  "tokenizer_class": "EsmTokenizer",
  "special_tokens_map": {
    "cls_token": "<cls>",
    "pad_token": "<pad>",
    "eos_token": "<eos>",
    "unk_token": "<unk>",
    "mask_token": "<mask>"
  }
}

AMPOS and AMPPT


Source: SPADE · Hosted on: Hugging Face · Owner: XJTLU_software 2025 · License: Apache-2.0 · Tags: biology · Size: 10K < n < 100K

Overview

  • AMP-Oriented Six tasks (AMPOS) is the dataset used to train AOMM.
  • The dataset originates from SPADE and is open-sourced on Hugging Face.
  • The AMPOS dataset brings a deep learning challenge, the AMP-oriented Multi-Property Prediction Task (AMPPT), to the community.

The ownership of the SPADE Database and the AMP_six_tasks dataset belongs to the iGEM team XJTLU_software 2025. If you have any questions or suggestions, please contact us at igem@xjtlu.edu.cn (please indicate your organization and the purpose of your email).

Dataset Structure (merged_data.json)
{
  "Sequence": {
    "task_name": "label"
  },
  "KGGK": {
    "amp_classification": 1.0,
    "bioactivity_classification": [
      "Anti-Gram+",
      "Anti-Gram-",
      "Antifungal",
      "Anti-Mammalian Cell"
    ],
    "hemolysis_regression": 0.24475088281037147,
    "mic_regression": {
      "Cryptococcus neoformans": 582.69,
      "Bacillus megaterium": 776.92,
      "Candida albicans": 4855.75,
      "Micrococcus luteus": 3107.68,
      "Acinetobacter baumannii": 4855.75,
      "Escherichia coli": 1165.38,
      "Pseudomonas aeruginosa": 2427.875,
      "Staphylococcus aureus": 2427.875
    }
  }
}

Details for AMPOS and AMPPT

The length of the peptides in the dataset is between 1 and 100 amino acids.

amp_classification

Label   | Amount | Source
AMP     | 42,542 | SPADE Database
Non-AMP | 49,513 | UniProt (keyword search)

bioactivity_classification

We unified the labels from the SPADE database; in total there are 68 labels. The list below shows the IDs and counts of the tags that occur more than 100 times:

Label               | ID | Amount
Anti-Gram+          | 0  | 11,078
Anti-Gram-          | 1  | 11,277
Anti-MRSA           | 2  | 251
Anti-Mammalian Cell | 3  | 4,857
Anti-bacterial      | 4  | 7,654
Anti-biofilm        | 5  | 243
Anti-cancer         | 6  | 4,676
Anti-fungal         | 7  | 5,876
Antimicrobial       | 8  | 23,622
Anti-parasitic      | 9  | 361
Anti-tumor          | 10 | 124
Anti-viral          | 11 | 3,599
Insecticidal        | 12 | 305

There are 35,282 AMPs with bioactivity tags; every such AMP has at least one tag.

Preprocessing before training

We retain the 13 tags in the list above and convert each label into a one-dimensional multi-hot tensor of length 13.

Preprocessing: bioactivity_classification
if 'bioactivity_classification' in label:
    multihot = [0] * len(bioactivity_projection)
    for l in label['bioactivity_classification']:
        if l in bioactivity_projection:
            multihot[bioactivity_projection[l]] = 1
    label['bioactivity_classification'] = torch.tensor(multihot, dtype=torch.float)

Example:

Example
["Anti-Gram+"] 
--->
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
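
The preprocessing snippet above relies on a bioactivity_projection mapping from tag names to tensor indices, which is not shown here. A sketch consistent with the ID column of the table above (label spellings may differ slightly in the raw data, e.g. "Antifungal" vs. "Anti-fungal"):

bioactivity_projection (sketch)
bioactivity_projection = {
    "Anti-Gram+": 0,
    "Anti-Gram-": 1,
    "Anti-MRSA": 2,
    "Anti-Mammalian Cell": 3,
    "Anti-bacterial": 4,
    "Anti-biofilm": 5,
    "Anti-cancer": 6,
    "Anti-fungal": 7,
    "Antimicrobial": 8,
    "Anti-parasitic": 9,
    "Anti-tumor": 10,
    "Anti-viral": 11,
    "Insecticidal": 12,
}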

half_life_regression

Unit: min

We selected in vitro experimental data as labels; the hosts were all mammalian cells. The number of samples is 8,497.

Preprocess for training

We applied -log10 to all values to reduce their range and used Z-score normalization to standardize the data.

Preprocessing: half_life_regression
def calculate_parameters4half_life(labels):
    values = []
    for label in labels:
        if 'half_life_regression' in label and label['half_life_regression'] is not None:
            values.append(label['half_life_regression'])
    assert len(values) > 0, "No half life values found"
    converted = -np.log10(np.array(values))
    mean = np.mean(converted)
    std = np.std(converted)
    return mean, std
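
Given the mean and standard deviation returned above, a single half-life label is mapped to its training target as sketched below (hypothetical helper; the actual dataloader code is not shown here):

Normalize half_life (sketch)
import numpy as np

def normalize_half_life(value_min, mean, std):
    """Map a raw half-life in minutes to the normalized training target."""
    return (-np.log10(value_min) - mean) / std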

Figure 1: Half life distribution after log transformation

hemolysis_regression

We convert the lethality and lethal concentration data of antimicrobial peptides into a hemolytic activity score ranging from 0 to 1.

The formulas are as follows:

Hemolysis Score Formula
E_clipped = max(1, min(99, E))
EC50 = C * (100 - E_clipped) / E_clipped
Activity = -log10(EC50)
MeanActivity = average(Activity)
Activity_clipped = max(-5, min(-1, MeanActivity))
Score = 0.1 + 0.8 * (Activity_clipped - (-5)) / ((-1) - (-5))

where E is the lethality (%), C is the lethal concentration (μg/ml), and the mean is taken over the N measurements available for a peptide. The closer the score is to 1, the stronger the hemolytic activity.
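
For concreteness, the formula can be translated directly into code. The function below is a sketch of the scoring as we understand it from the formula, not the exact pipeline code:

Hemolysis Score (sketch)
import numpy as np

def hemolysis_score(lethality_pct, concentration_ug_ml):
    """Convert paired lethality (%) / lethal concentration (μg/ml) measurements
    into a hemolytic activity score, following the formula above."""
    E = np.clip(np.asarray(lethality_pct, dtype=float), 1, 99)
    C = np.asarray(concentration_ug_ml, dtype=float)
    ec50 = C * (100 - E) / E
    activity = -np.log10(ec50)
    activity_clipped = np.clip(np.mean(activity), -5, -1)
    return 0.1 + 0.8 * (activity_clipped - (-5)) / ((-1) - (-5))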

Figure 2: Hemolytic activity score distribution

Figure 3: Hemolytic activity score vs lethal concentration (μg/ml)

Preprocessing before training

We applied ln(x+1) to all values to reduce their range and used Z-score normalization to standardize the data.

Preprocessing: hemolysis_regression
def calculate_parameters4hemolysis(labels):
    values = []
    for label in labels:
        if 'hemolysis_regression' in label and label['hemolysis_regression'] is not None:
            values.append(label['hemolysis_regression'])
    assert len(values) > 0, "No hemolysis values found"
    converted = np.log1p(values)
    mean = np.mean(converted)
    std = np.std(converted)
    return mean, std

mic_regression

Unit: μg/ml

There are 5,824 AMPs with mic_regression labels. An antimicrobial peptide may have multiple target organisms, each with its own MIC value. We clip the original MIC values with an upper threshold of 100,000 μg/ml and a lower threshold of 0.001 μg/ml, and we retain microorganisms that appear more than 100 times. The list below shows the organisms that appear more than 500 times:

Organism                   | ID | Samples
Acinetobacter baumannii    | 0  | 711
Bacillus subtilis          | 1  | 1,515
Candida albicans           | 2  | 1,314
Enterococcus faecalis      | 3  | 681
Escherichia coli           | 4  | 4,461
Klebsiella pneumoniae      | 5  | 956
Micrococcus luteus         | 6  | 570
Pseudomonas aeruginosa     | 7  | 2,551
Salmonella enterica        | 8  | 966
Staphylococcus aureus      | 9  | 3,850
Staphylococcus epidermidis | 10 | 1,097

Data distribution:


Figure 4: Overall MIC distribution

Preprocess for training

We retain the organisms in the list above, which yields 18,668 samples. We applied -log10 to all values and applied strain-specific (per-organism) normalization.

Preprocess: Calculate MIC stats by organism
def calculate_mic_stats_by_organism(labels, organism_projection):
    idx_to_org = {idx: org for org, idx in organism_projection.items()}
    mic_values = {org_idx: [] for org_idx in idx_to_org.keys()}
    
    for label in labels:
        if 'mic_regression' in label:
            for org, value in label['mic_regression'].items():
                if org in organism_projection and value > 0:
                    org_idx = organism_projection[org]
                    log_val = -math.log10(value)
                    mic_values[org_idx].append(log_val)
    
    # calculate average value and mean standard deviation for each organism
    mic_stats = {}
    for org_idx, values in mic_values.items():
        if values:  # make sure there are values
            mean = np.mean(values)
            std = np.std(values)
            mic_stats[org_idx] = (mean, std)
    
    return mic_stats
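
Given the per-organism statistics, a single MIC label is normalized with the mean and standard deviation of its target organism. A small sketch (hypothetical helper, matching the statistics computed above):

Normalize MIC (sketch)
import math

def normalize_mic(value_ug_ml, org_idx, mic_stats):
    """Normalize one MIC value using its organism-specific statistics."""
    mean, std = mic_stats[org_idx]
    return (-math.log10(value_ug_ml) - mean) / std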

How to use data to train model

Run merged_dataloader.py to generate the data loaders for a given task:

Build dataloaders
if __name__ == '__main__':
    file_path = "merged_data.json"
    train_loader, val_loader, test_loader = build_dataloader(file_path, "bioactivity_classification", 64)

Use the denormalize function together with the file normalization_parameters.csv to denormalize the model outputs for the regression tasks:

Denormalize helper
import numpy as np
import pandas as pd
import torch

norm_params_path = "normalization_parameters.csv"       # change this to your path
norm_params = pd.read_csv(norm_params_path)
def denormalize(task_name, normalized_value, organism_id=None):
    """
    Denormalize the value for regression tasks
    """
    if isinstance(organism_id, torch.Tensor):
        organism_id = organism_id.cpu().item()

    if task_name == "mic_regression":
        org_name = norm_params.iloc[organism_id]['organism']
        params = norm_params[
            (norm_params['parameter_type'] == 'mic_regression') &
            (norm_params['organism'] == org_name)
        ]
        mean = params['mean'].values[0]
        std = params['std'].values[0]
        log_val = normalized_value * std + mean
        return 10 ** (-log_val)  # 10^(-log10(value)) = value
    elif task_name == "half_life_regression":
        params = norm_params[norm_params['parameter_type'] == 'half_life_regression']
        mean = params['mean'].values[0]
        std = params['std'].values[0]
        log_val = normalized_value * std + mean
        return 10 ** (-log_val)    
    elif task_name == "hemolysis_regression":
        params = norm_params[norm_params['parameter_type'] == 'hemolysis_regression']
        mean = params['mean'].values[0]
        std = params['std'].values[0]
        log_val = normalized_value * std + mean
        return np.expm1(log_val)  # e^(log_val) - 1

    return normalized_value  # the rest of tasks are not normalized

Training


Overview

During the whole training process, we use AdamW as the optimizer and set the maximum number of epochs to 100.

  • We use early stopping: training stops when the validation loss does not decrease for 5 epochs, to avoid overfitting.
  • We use gradient clipping to improve training stability.
Training Overview
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

We save the model parameters with the best validation loss.
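
A minimal sketch of the early-stopping and best-checkpoint logic described above (train_one_epoch and evaluate are assumed helpers; the actual Trainer.py implementation may differ):

Early Stopping (sketch)
import torch

best_val_loss = float("inf")
patience, epochs_without_improvement = 5, 0

for epoch in range(100):
    train_one_epoch(model, train_loader, optimizer)       # assumed helper
    val_loss = evaluate(model, val_loader)                 # assumed helper

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), "best_model.pth")   # keep the best checkpoint
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch}")
            break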

Training file structure
- AOMM (downloaded from Hugging Face)\
    - config.json
    - model.safetensors
    - modeling_amp.py
    - normalization_parameters.csv
    - pytorch_model.bin
    - README.md
    - tokenizer_config.json
    - vocab.txt
- merged_data.json
- merged_dataloader.py
- pretrain.py
- Trainer.py

Loss functions for training

Loss functions measure the difference between a model's predictions and the true target values, guiding parameter updates during network training. Below are four commonly used loss functions in PyTorch (denoted by their nn module names) with mathematical explanations:

nn.CrossEntropyLoss

  • Core Principle
    Designed for multi-class classification tasks (where each sample belongs to exactly one class). It implicitly combines nn.LogSoftmax (to convert logits to normalized log-probabilities) and nn.NLLLoss (Negative Log-Likelihood Loss) to avoid numerical instability from separate softmax calculations.
  • Mathematical Expression
    For a single sample with:
    Logits (model outputs before softmax): \( z = [z_1, z_2, ..., z_C] \) (where \( C \) is the number of classes),
    True class label: \( y \) (an integer in \( [0, C-1] \), not one-hot encoded),
    the loss is calculated as:
    \[ \text{CrossEntropyLoss}(z, y) = -\log\left( \frac{e^{z_y}}{\sum_{i=1}^C e^{z_i}} \right) = -z_y + \log\left( \sum_{i=1}^C e^{z_i} \right) \]
    For a batch of \( N \) samples, the average loss is:
    \[ \text{CrossEntropyLoss}(Z, Y) = \frac{1}{N} \sum_{n=1}^N \left[ -Z_{n,Y_n} + \log\left( \sum_{i=1}^C e^{Z_{n,i}} \right) \right] \]
    where \( Z \in \mathbb{R}^{N \times C} \) is the batch of logits, and \( Y = [y_1, y_2, ..., y_N] \) is the batch of true labels.
  • Application Scenarios
    Image classification (e.g., ImageNet 1000-class tasks), text category classification (e.g., news topic labeling).

nn.BCEWithLogitsLoss

  • Core Principle
    Used for binary classification (two classes) or multi-label classification (each sample can belong to multiple classes). It integrates nn.Sigmoid (to map logits to probabilities in \( [0,1] \)) and nn.BCELoss (Binary Cross-Entropy Loss) to stabilize training.
  • Mathematical Expression
    For a single sample with:
    Logit (model output for one class): \( z \),
    True label: \( y \) (0 or 1, since it's binary per class), the loss is:
    \[ \text{BCEWithLogitsLoss}(z, y) = -\big[y \cdot \log(\sigma(z)) + (1-y) \cdot \log(1-\sigma(z))\big] \]
    where \(\sigma(z) = \frac{1}{1+e^{-z}}\) is the Sigmoid function (converts logits to probabilities).

    For a batch of \( N \) samples and \( K \) classes (multi-label scenario), the average loss is:
    \[ \text{BCEWithLogitsLoss}(Z, Y) = \frac{1}{N \cdot K} \sum_{n=1}^N \sum_{k=1}^K \Big[ -Y_{n,k} \cdot \log(\sigma(Z_{n,k})) - (1-Y_{n,k}) \cdot \log(1-\sigma(Z_{n,k})) \Big] \]
    where \( Z \in \mathbb{R}^{N \times K} \) is the batch of logits, and \( Y \in \{0,1\}^{N \times K} \) is the batch of true multi-label matrices.
  • Application Scenarios
    Binary tasks (e.g., spam detection, medical disease diagnosis), multi-label tasks (e.g., image tagging: a photo may contain "cat" and "tree" simultaneously).

nn.MSELoss

  • Core Principle
    Short for Mean Squared Error Loss, it measures the average of the squared differences between predictions and true values. Primarily used for regression tasks (predicting continuous values) due to its smooth gradient.
  • Mathematical Expression
    For a single sample with:
    Model prediction: \( \hat{y} \),
    True target value: \( y \) (a continuous number, e.g., house price, temperature),
    the loss is:
    \[ \text{MSELoss}(\hat{y}, y) = (\hat{y} - y)^2 \]
    For a batch of \( N \) samples, the average loss is:
    \[ \text{MSELoss}(\hat{Y}, Y) = \frac{1}{N} \sum_{n=1}^N (\hat{Y}_n - Y_n)^2 \]
    where \( \hat{Y} = [\hat{y}_1, \hat{y}_2, ..., \hat{y}_N] \) is the batch of predictions, and \( Y = [y_1, y_2, ..., y_N] \) is the batch of true values.
  • Application Scenarios
    Continuous value prediction (e.g., house price regression, time-series temperature forecasting, image super-resolution). Note: MSE is sensitive to outliers (large errors are amplified by squaring).

nn.HuberLoss

  • Core Principle
    A robust regression loss that balances the smoothness of MSE and the outlier resistance of MAE (Mean Absolute Error). It uses a hyperparameter \( \delta > 0 \) to switch between quadratic (MSE-like) and linear (MAE-like) loss:
    When the absolute error is small (\( |\hat{y} - y| \leq \delta \)): Uses MSE (smooth gradient for stable training).
    When the absolute error is large (\( |\hat{y} - y| > \delta \)): Uses MAE (avoids amplifying outlier errors).
  • Mathematical Expression
    For a single sample:
    \[ \text{HuberLoss}(\hat{y}, y; \delta) = \begin{cases} \frac{1}{2} (\hat{y} - y)^2 & \text{if } |\hat{y} - y| \leq \delta, \\ \delta \cdot |\hat{y} - y| - \frac{1}{2} \delta^2 & \text{if } |\hat{y} - y| > \delta. \end{cases} \]
    For a batch of \( N \) samples, the average loss is:
    \[ \text{HuberLoss}(\hat{Y}, Y; \delta) = \frac{1}{N} \sum_{n=1}^N \begin{cases} \frac{1}{2} (\hat{Y}_n - Y_n)^2 & \text{if } |\hat{Y}_n - Y_n| \leq \delta, \\ \delta \cdot |\hat{Y}_n - Y_n| - \frac{1}{2} \delta^2 & \text{if } |\hat{Y}_n - Y_n| > \delta. \end{cases} \]
  • Application Scenarios
    Regression tasks with outliers (e.g., sensor data prediction with noise, autonomous driving speed estimation), where MSE would be distorted by extreme values. The choice of \( \delta \) depends on the data: smaller \( \delta \) makes it more like MSE, while larger \( \delta \) makes it more like MAE.
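
In AOMM these losses map onto the subtasks as listed in the training table at the end of this page. A brief sketch of how they can be instantiated in PyTorch:

Task Loss Functions (sketch)
import torch.nn as nn

task_losses = {
    "amp_classification": nn.CrossEntropyLoss(),
    "bioactivity_classification": nn.BCEWithLogitsLoss(),   # multi-label, multi-hot targets
    "half_life_regression": nn.MSELoss(),
    "hemolysis_regression": nn.MSELoss(),
    "mic_regression": nn.HuberLoss(delta=1.0),              # robust to MIC outliers
}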

Pretrain

Run pretrain.py to pretrain AOMM with Masked Language Modeling (MLM).

Pretrain Details

Learning rate: 5e-5

We sample a few tokens in each sequence for MLM training.

  • 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  • 10% of the time, we replace masked input tokens with a random token
  • The remaining 10% of the time, we keep the masked input tokens unchanged
Preprocessing: Masking Tokens
def mask_tokens(self, inputs):
    """
    Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
    """
    labels = inputs.clone()
    # We sample a few tokens in each sequence for MLM training (with probability `self.mask_prob`)
    special_tokens_mask = torch.zeros_like(inputs, dtype=torch.bool)
    for token_id in self.special_token_ids:
        special_tokens_mask |= (inputs == token_id)

    probability_matrix = torch.full(labels.shape, self.mask_prob)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, self.prob_replace_mask)).bool() & masked_indices
    inputs[indices_replaced] = self.mask_token_id

    # 10% of the time, we replace masked input tokens with random word
    current_prob = self.prob_replace_rand / (1 - self.prob_replace_mask)
    indices_random = torch.bernoulli(torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
    for token_id in self.special_token_ids:
        indices_random = indices_random & (inputs != token_id)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels

The loss function is CrossEntropyLoss, which ignores the non-masked positions (label -100):

Loss Function
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

Multi-task Training

Run Trainer.py to perform multi-task training. The first 5 layers of the encoder are frozen during this stage.

Multi-task Training Details

  • Learning rate: 2e-5
  • During task switching, initialize the model with the best model from the previous task
Preprocessing: Freezing Encoder Layers
def freeze_encoder_layers(self, num_frozen_layers: int=5):
    """freeze some certain layers of the encoder"""
    assert num_frozen_layers <= len(self.model.amp.layers) and num_frozen_layers > 0, "Invalid number of frozen layers"

    # freeze the word embeddings
    for param in self.model.amp.word_embeddings.parameters():
        param.requires_grad = False

    # freeze the layer norm layer
    for param in self.model.amp.emb_layer_norm.parameters():
        param.requires_grad = False

    # freeze some layers of the encoder
    for i, layer in enumerate(self.model.amp.layers):
        if i < num_frozen_layers:
            for param in layer.parameters():
                param.requires_grad = False

When switching tasks, we initialize the model from the best checkpoint of the previous task, as implemented below:

Training Loop: Initialize from Previous Best
def train_all_tasks(self, epochs_per_task=100, start_from=None):
    if start_from:
        start_idx = self.task_order.index(start_from)
        tasks_to_train = self.task_order[start_idx:]
    else:
        tasks_to_train = self.task_order

    for i, task in enumerate(tasks_to_train):
        print(f"\n=== Starting Task {i+1}/{len(tasks_to_train)}: {task} ===")
        if i > 0:
            prev_task = tasks_to_train[i-1]  # previous task within this training run
            checkpoint_path = os.path.join(self.save_dir, f"best_{prev_task}_model.pth")
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
            self.model.load_state_dict(checkpoint)

        self.train_task(task, epochs=epochs_per_task)
        print(f"Completed training for {task}\n")

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

Training Task Order

  • amp_classification
  • bioactivity_classification
  • half_life_regression
  • hemolysis_regression
  • mic_regression
Task                       | Loss Function        | Best Validation Loss | Epoch (Early Stop)
amp_classification         | CrossEntropyLoss()   | 0.1079               | 6
bioactivity_classification | BCEWithLogitsLoss()  | 0.2076               | 11
half_life_regression       | MSELoss()            | 0.0118               | 16
hemolysis_regression       | MSELoss()            | 0.3841               | 30
mic_regression             | HuberLoss(delta=1.0) | 0.0863               | 30