AMP-Oriented Multi-task Model (AOMM)
The AMPOS dataset poses a bioinformatics challenge: to evaluate antimicrobial peptides comprehensively, a single neural network framework must perform well on all subtasks of AMPPT (sequence mask prediction, AMP classification, half-life regression, microorganism-specific minimum inhibitory concentration regression, and hemolytic activity score regression). To address this challenge, we propose the AMP-Oriented Multi-task Model (AOMM), which achieves state-of-the-art performance on all AMPPT subtasks.
- The number of parameters is approximately 124M.
- The model is open source on HuggingFace.
from transformers import AutoModel
model = AutoModel.from_pretrained(
"muskwff/amp4multitask_124M",
task_name="mic_regression", # choose the task ["amp_classification", "hemolysis_regression", "mic_regression", "half_life_regression"]
trust_remote_code=True # required for loading the model
)
The code is available on GitLab.
Model Introduction
The model AOMM is a Transformer[1] model focusing on six tasks related to antimicrobial peptides (AMPs):
- sequence mask prediction: masked-sequence training, a common task in NLP (auxiliary task)
- AMP classification: Determine whether a peptide is an antimicrobial peptide
- microorganism-specific minimum inhibitory concentration regression
- half-life regression
- hemolytic activity score regression
- bioactivity classification (auxiliary task)
[1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.
Figure 1: The model architecture of AOMM
Model Architecture
The model architecture consists of two parts, an encoder and a decoder; that is, the model is an encoder-decoder model. The main hyperparameters of the model are shown below:
| model_max_length | num_hidden_layers | hidden_size | num_attention_heads |
|---|---|---|---|
| 128 | 18 | 768 | 24 |
We use the whole sequence feature as the decoder input for the mask-prediction task, while the decoder input for the other tasks is the CLS token of the sequence feature, as sketched below.
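For illustration, given encoder output of shape (batch_size, seq_len, hidden_size), the routing could look like the following two lines (the tensor name hidden_states is hypothetical):
seq_features = hidden_states            # (B, 128, 768): full sequence features -> mask-prediction head
cls_features = hidden_states[:, 0, :]   # (B, 768): <cls> token feature -> classification/regression heads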
Encoder (AMP Layer)
The encoder is a Transformer encoder composed of 18 layers. Each layer contains a multi-head self-attention mechanism and a feed-forward network; the feed-forward network consists of two linear layers followed by a layer normalization, and each layer uses two skip (residual) connections. The input of the encoder is the tokenized sequence, and the output features have shape (batch_size, seq_len, hidden_size), i.e. [B, 128, 768].
Figure 2: The structure of an AMP layer
import torch.nn as nn

class AMPLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # AMPAttention (multi-head self-attention with rotary position encoding) is described in the next subsection
        self.attention = AMPAttention(config)
        self.attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.intermediate = nn.Linear(config.hidden_size, config.intermediate_size)
        self.output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.activation = nn.GELU()
        self.config = config

    def forward(self, hidden_states, attention_mask=None):
        # self-attention sub-layer with residual connection and post-layer-norm
        residual = hidden_states
        attention_output = self.attention(hidden_states, attention_mask)
        attention_output = self.dropout(attention_output)
        attention_output = residual + attention_output
        attention_output = self.attn_layer_norm(attention_output)
        # feed-forward sub-layer with residual connection and post-layer-norm
        residual = attention_output
        intermediate_output = self.intermediate(attention_output)
        intermediate_output = self.activation(intermediate_output)
        layer_output = self.output(intermediate_output)
        layer_output = self.dropout(layer_output)
        layer_output = residual + layer_output
        layer_output = self.final_layer_norm(layer_output)
        return layer_output
AMP Attention
AMP Attention is a multi-head self-attention module. Within the attention mechanism, we use rotary position encoding (RoPE) to capture the positional relationships between input tokens.
Figure 3: The structure of AMP Attention
The self-attention procedure with RoPE works as follows: for each token embedding in the sequence, first compute its query and key vectors; then compute the rotary position encoding for each token position; next apply the rotation transformation pairwise to the elements of the query and key vectors at each position; finally compute the inner product between the rotated queries and keys to obtain the self-attention scores.
Figure 4: Rotation transformation demonstration of the RoPE algorithm
The RoPE code in AMP Attention is as follows:
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=512):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # precompute_freqs_cis builds the complex rotation factors for every position
        freqs_cis = precompute_freqs_cis(dim, max_seq_len)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, seq_len=None):
        seq_len = seq_len if seq_len is not None else self.max_seq_len
        return self.freqs_cis[:seq_len]
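The helper precompute_freqs_cis is defined elsewhere in the repository and is not shown above. As a reference only, a minimal sketch of it in the standard complex-number formulation of RoPE, together with a hypothetical apply_rotary_emb helper for rotating the query and key tensors, could look like this (dim is the per-head dimension):
import torch

def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # one rotation frequency per pair of dimensions
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(t, freqs)                       # (max_seq_len, dim // 2)
    return torch.polar(torch.ones_like(angles), angles)  # complex e^{i * position * freq}

def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, num_heads, head_dim); pair up the last dimension as complex numbers
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs = freqs_cis[: x.shape[1]].unsqueeze(0).unsqueeze(2)  # (1, seq_len, 1, head_dim // 2)
    x_rotated = torch.view_as_real(x_complex * freqs).flatten(3)
    return x_rotated.type_as(x)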
Decoder (Task-Aware Head)
The decoder is a coupled multilayer perceptron (MLP) block containing one sub-block per task (six in total). Except for the MIC regression task, each task's decoder is a single MLP block.
Figure 5: The structure and parameter configuration of the Decoder
For the MIC regression task, we combine the sequence feature with an organism feature. Thus, we design an organism_encoder to obtain the organism feature and an interactor to fuse the sequence feature with the organism feature.
import torch
import torch.nn as nn

class ActivityRegressionHead(nn.Module):
    def __init__(self, embedding_dim, num_organisms):
        super().__init__()
        # embed the target organism id and project it into the sequence feature space
        self.organism_encoder = nn.Sequential(
            nn.Embedding(num_organisms, embedding_dim),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, embedding_dim),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        # fuse the concatenated sequence and organism features
        self.interactor = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.GELU(),
            nn.Dropout(0.2)
        )
        # map the fused feature to a single MIC value
        self.regressor = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(embedding_dim // 2, embedding_dim // 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim // 4, 1)
        )

    def forward(self, features, organism_ids):
        org_features = self.organism_encoder(organism_ids)
        combined = torch.cat([features, org_features], dim=1)
        interaction = self.interactor(combined)
        interaction = interaction + features  # residual fusion with sequence features
        return self.regressor(interaction)
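A quick, hypothetical smoke test of this head (the dimensions follow the hyperparameter table above and the organism list later on this page):
import torch

head = ActivityRegressionHead(embedding_dim=768, num_organisms=11)
features = torch.randn(4, 768)              # CLS features from the encoder, (batch, hidden_size)
organism_ids = torch.tensor([0, 4, 9, 7])   # organism indices, e.g. 4 is Escherichia coli
pred = head(features, organism_ids)         # (4, 1) normalized MIC predictions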
All task heads are stored in the AMPForMultiTask module as a dictionary (nn.ModuleDict):
self.task_heads = nn.ModuleDict()
for task_name, single_task_config in task_config.items():  # task_config maps task names to their configs
    num_orgs = single_task_config.get("num_organisms")
    self.task_heads[task_name] = TaskAwareHead(
        config,
        str(task_name),
        num_organisms=num_orgs
    )
Tokenizer Configuration
The tokenizer used in AOMM is the same as that of the ESM-2 model, which can be found on HuggingFace. Because the length of antimicrobial peptides is generally between 1 and 100 amino acids, we set model_max_length to 128.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("muskwff/amp4multitask_124M")
{
"model_max_length": 128,
"tokenizer_class": "EsmTokenizer",
"special_tokens_map": {
"cls_token": "<cls>",
"pad_token": "<pad>",
"eos_token": "<eos>",
"unk_token": "<unk>",
"mask_token": "<mask>"
}
}
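As a usage sketch (standard Hugging Face tokenizer API; the padding strategy below is our assumption rather than a requirement), a peptide such as KGGK can be encoded like this:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("muskwff/amp4multitask_124M")
enc = tokenizer("KGGK", padding="max_length", truncation=True,
                max_length=128, return_tensors="pt")
print(enc["input_ids"].shape)  # torch.Size([1, 128])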
AMPOS and AMPPT
Overview
- AMP-Oriented Six tasks (AMPOS) is the dataset used to train AOMM.
- The dataset originates from SPADE and is open-sourced on Hugging Face.
- The AMPOS dataset brings a deep learning challenge, the AMP-oriented Multi-Property Prediction Task (AMPPT), to the community.
The ownership of the SPADE Database and the AMP_six_tasks dataset belongs to the iGEM team XJTLU_software 2025. If you have any questions or suggestions, please contact us at igem@xjtlu.edu.cn (please indicate your organization and the purpose of your email).
An example record from AMPOS is shown below (the label schema, followed by a sample entry for the peptide KGGK):
{
"Sequence": {
"task_name": "label"
},
"KGGK": {
"amp_classification": 1.0,
"bioactivity_classification": [
"Anti-Gram+",
"Anti-Gram-",
"Antifungal",
"Anti-Mammalian Cell"
],
"hemolysis_regression": 0.24475088281037147,
"mic_regression": {
"Cryptococcus neoformans": 582.69,
"Bacillus megaterium": 776.92,
"Candida albicans": 4855.75,
"Micrococcus luteus": 3107.68,
"Acinetobacter baumannii": 4855.75,
"Escherichia coli": 1165.38,
"Pseudomonas aeruginosa": 2427.875,
"Staphylococcus aureus": 2427.875
}
}
}
Details for AMPOS and AMPPT
The length of the peptides in the dataset is between 1 and 100 amino acids.
amp_classification
| Label | Amount | Source |
|---|---|---|
| AMP | 42,542 | SPADE Database |
| Non-AMP | 49,513 | Uniprot (keywords search) |
bioactivity_classification
We unified the labels from the SPADE database; in total, there are 68 labels. The table below lists the IDs and counts of the tags that occur more than 100 times:
| Label | ID | Amount |
|---|---|---|
| Anti-Gram+ | 0 | 11,078 |
| Anti-Gram- | 1 | 11,277 |
| Anti-MRSA | 2 | 251 |
| Anti-Mammalian Cell | 3 | 4,857 |
| Anti-bacterial | 4 | 7,654 |
| Anti-biofilm | 5 | 243 |
| Anti-cancer | 6 | 4,676 |
| Anti-fungal | 7 | 5,876 |
| Antimicrobial | 8 | 23,622 |
| Anti-parasitic | 9 | 361 |
| Anti-tumor | 10 | 124 |
| Anti-viral | 11 | 3,599 |
| Insecticidal | 12 | 305 |
There are 35,282 AMPs with bioactivity tags, and every AMP carries at least one tag.
Preprocessing before training
We retain the 13 tags from the table above and convert each label into a one-dimensional multi-hot tensor of length 13.
import torch

# bioactivity_projection maps each retained tag name to its index (0-12)
if 'bioactivity_classification' in label:
    multihot = [0] * len(bioactivity_projection)
    for tag in label['bioactivity_classification']:
        if tag in bioactivity_projection:
            multihot[bioactivity_projection[tag]] = 1
    label['bioactivity_classification'] = torch.tensor(multihot, dtype=torch.float)
Example:
["Anti-Gram+"]
--->
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
half_life_regression
Unit: min
We selected in vitro experimental data as labels; the hosts were all mammalian cells. The number of samples is 8,497.
Preprocess for training
We applied -log10 to all values to reduce their range and then used Z-score normalization to standardize the data.
import numpy as np

def calculate_parameters4half_life(labels):
    values = []
    for label in labels:
        if 'half_life_regression' in label and label['half_life_regression'] is not None:
            values.append(label['half_life_regression'])
    assert len(values) > 0, "No half life values found"
    converted = -np.log10(np.array(values))
    mean = np.mean(converted)
    std = np.std(converted)
    return mean, std
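The resulting statistics can then be applied to a raw half-life value, for example as follows (a sketch; labels is the list of label dictionaries from the dataset, and the raw value is made up):
import numpy as np

mean, std = calculate_parameters4half_life(labels)
raw_half_life = 120.0                                  # example raw value in minutes
normalized = (-np.log10(raw_half_life) - mean) / std   # z-score of -log10(value)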
Figure 1: Half life distribution after log transformation
hemolysis_regression
We convert the lethality and lethal concentration data of antimicrobial peptides into a hemolytic activity score ranging from 0 to 1.
The formulas are as follows:
\[ E_{\text{clipped}} = \max(1, \min(99, E)) \]
\[ \text{EC}_{50} = C \cdot \frac{100 - E_{\text{clipped}}}{E_{\text{clipped}}} \]
\[ \text{Activity} = -\log_{10}(\text{EC}_{50}) \]
\[ \text{MeanActivity} = \frac{1}{N} \sum_{i=1}^{N} \text{Activity}_i \]
\[ \text{Activity}_{\text{clipped}} = \max(-5, \min(-1, \text{MeanActivity})) \]
\[ \text{Score} = 0.1 + 0.8 \cdot \frac{\text{Activity}_{\text{clipped}} - (-5)}{(-1) - (-5)} \]
where E is the lethality (%), C is the lethal concentration (μg/ml), and N is the number of samples for a peptide. The closer the score is to 1, the stronger the hemolytic activity.
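A small sketch of this conversion, assuming the raw data for one peptide are (lethality %, lethal concentration μg/ml) pairs (hemolysis_score is a hypothetical helper, not part of the released code):
import numpy as np

def hemolysis_score(measurements):
    """Convert (lethality %, lethal concentration ug/ml) pairs into a score in [0, 1]."""
    activities = []
    for E, C in measurements:
        E_clipped = max(1.0, min(99.0, E))            # clip lethality to [1, 99] %
        ec50 = C * (100.0 - E_clipped) / E_clipped    # estimated EC50
        activities.append(-np.log10(ec50))
    mean_activity = float(np.mean(activities))
    activity_clipped = max(-5.0, min(-1.0, mean_activity))
    return 0.1 + 0.8 * (activity_clipped + 5.0) / 4.0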
- Distribution:
Figure 2: Hemolytic activity score distribution
- Concentration vs score:
Figure 3: Score vs lethal concentration (μg/ml)
Preprocessing before training
We applied ln(x+1) to all values to reduce their range and then used Z-score normalization to standardize the data.
import numpy as np

def calculate_parameters4hemolysis(labels):
    values = []
    for label in labels:
        if 'hemolysis_regression' in label and label['hemolysis_regression'] is not None:
            values.append(label['hemolysis_regression'])
    assert len(values) > 0, "No hemolysis values found"
    converted = np.log1p(values)
    mean = np.mean(converted)
    std = np.std(converted)
    return mean, std
mic_regression
Unit: μg/ml
There are 5,824 AMPs with mic_regression labels. An antimicrobial peptide may have multiple target organisms, each with a corresponding MIC value. We clip the original MIC values using a high threshold of 100,000 μg/ml and a low threshold of 0.001 μg/ml, and we retain microorganisms that appear more than 100 times. The table below lists the organisms that appear more than 500 times:
| Organism | ID | Samples |
|---|---|---|
| Acinetobacter baumannii | 0 | 711 |
| Bacillus subtilis | 1 | 1,515 |
| Candida albicans | 2 | 1,314 |
| Enterococcus faecalis | 3 | 681 |
| Escherichia coli | 4 | 4,461 |
| Klebsiella pneumoniae | 5 | 956 |
| Micrococcus luteus | 6 | 570 |
| Pseudomonas aeruginosa | 7 | 2,551 |
| Salmonella enterica | 8 | 966 |
| Staphylococcus aureus | 9 | 3,850 |
| Staphylococcus epidermidis | 10 | 1,097 |
Data distribution:
Figure 4: Overall MIC distribution
Preprocess for training
We retain only the organisms in the table above. The number of samples is 18,668.
We applied -log10 to all values and then performed organism-specific Z-score normalization.
import math
import numpy as np

def calculate_mic_stats_by_organism(labels, organism_projection):
    idx_to_org = {idx: org for org, idx in organism_projection.items()}
    mic_values = {org_idx: [] for org_idx in idx_to_org.keys()}
    for label in labels:
        if 'mic_regression' in label:
            for org, value in label['mic_regression'].items():
                if org in organism_projection and value > 0:
                    org_idx = organism_projection[org]
                    log_val = -math.log10(value)
                    mic_values[org_idx].append(log_val)
    # calculate the mean and standard deviation for each organism
    mic_stats = {}
    for org_idx, values in mic_values.items():
        if values:  # make sure there are values
            mean = np.mean(values)
            std = np.std(values)
            mic_stats[org_idx] = (mean, std)
    return mic_stats
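The resulting statistics can then be used to normalize a raw MIC label, for example (a sketch; labels and organism_projection are the same objects as above, and the raw MIC value is made up):
import math

mic_stats = calculate_mic_stats_by_organism(labels, organism_projection)
org_idx = organism_projection["Escherichia coli"]
mean, std = mic_stats[org_idx]
normalized = (-math.log10(32.0) - mean) / std   # 32 ug/ml is an example raw MIC value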
How to use the data to train the model
Run merged_dataloader.py to generate the data loaders for a given task:
if __name__ == '__main__':
file_path = "merged_data.json"
train_loader, val_loader, test_loader = build_dataloader(file_path, "bioactivity_classification", 64)
Use the denormalize function together with the file normalization_parameters.csv to denormalize the model outputs for the regression tasks:
import numpy as np
import pandas as pd
import torch

norm_params_path = "normalization_parameters.csv"  # change this to your path
norm_params = pd.read_csv(norm_params_path)

def denormalize(task_name, normalized_value, organism_id=None):
    """
    Denormalize the value for regression tasks
    """
    if isinstance(organism_id, torch.Tensor):
        organism_id = organism_id.cpu().item()
    if task_name == "mic_regression":
        org_name = norm_params.iloc[organism_id]['organism']
        params = norm_params[
            (norm_params['parameter_type'] == 'mic_regression') &
            (norm_params['organism'] == org_name)
        ]
        mean = params['mean'].values[0]
        std = params['std'].values[0]
        log_val = normalized_value * std + mean
        return 10 ** (-log_val)  # invert -log10(value)
    elif task_name == "half_life_regression":
        params = norm_params[norm_params['parameter_type'] == 'half_life_regression']
        mean = params['mean'].values[0]
        std = params['std'].values[0]
        log_val = normalized_value * std + mean
        return 10 ** (-log_val)
    elif task_name == "hemolysis_regression":
        params = norm_params[norm_params['parameter_type'] == 'hemolysis_regression']
        mean = params['mean'].values[0]
        std = params['std'].values[0]
        log_val = normalized_value * std + mean
        return np.expm1(log_val)  # invert ln(x + 1)
    return normalized_value  # the remaining tasks are not normalized
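For example (the numeric values here are made up for illustration; organism id 4 corresponds to Escherichia coli in the organism table above):
mic_ug_per_ml = denormalize("mic_regression", 0.37, organism_id=4)
half_life_min = denormalize("half_life_regression", -0.52)
hemolysis = denormalize("hemolysis_regression", 0.10)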
Training
Overview
Throughout training, we use AdamW as the optimizer and set the maximum number of epochs to 100.
- We use early stopping: training is stopped when the validation loss has not decreased for 5 epochs, to avoid overfitting.
- We use gradient clipping to increase the stability of training.
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
We save the model parameters that achieve the best validation loss. A minimal sketch of this training loop is shown below.
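In this sketch, compute_loss and evaluate are hypothetical helpers, and model, optimizer, config, and the data loaders come from the surrounding training script:
best_val_loss, patience, bad_epochs = float("inf"), 5, 0
for epoch in range(100):
    model.train()
    for batch in train_loader:
        loss = compute_loss(model, batch)              # hypothetical helper
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()
    val_loss = evaluate(model, val_loader)             # hypothetical helper
    if val_loss < best_val_loss:                       # keep the best checkpoint
        best_val_loss, bad_epochs = val_loss, 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        bad_epochs += 1
        if bad_epochs >= patience:                     # early stopping after 5 epochs without improvement
            break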
The project files are organized as follows:
- AOMM/ (downloaded from HuggingFace)
- config.json
- model.safetensors
- modeling_amp.py
- normalization_parameters.py
- pytorch_model.bin
- README.md
- tokenizer_config.json
- vocab.txt
- merged_data.json
- merged_dataloader.py
- pretrain.py
- Trainer.py
Loss functions for training
Loss functions measure the difference between a model's predictions and the true target values, guiding parameter updates during network training. Below are four commonly used loss functions in PyTorch (denoted by their nn module names) with mathematical explanations:
nn.CrossEntropyLoss
- Core Principle
Designed for multi-class classification tasks (where each sample belongs to exactly one class). It implicitly combines nn.LogSoftmax (to convert logits to normalized log-probabilities) and nn.NLLLoss (Negative Log-Likelihood Loss) to avoid numerical instability from separate softmax calculations.
- Mathematical Expression
For a single sample with:
Logits (model outputs before softmax): \( z = [z_1, z_2, ..., z_C] \) (where \( C \) is the number of classes),
True class label: \( y \) (an integer in \( [0, C-1] \), not one-hot encoded),
the loss is calculated as:
\[ \text{CrossEntropyLoss}(z, y) = -\log\left( \frac{e^{z_y}}{\sum_{i=1}^C e^{z_i}} \right) = -z_y + \log\left( \sum_{i=1}^C e^{z_i} \right) \]
For a batch of \( N \) samples, the average loss is:
\[ \text{CrossEntropyLoss}(Z, Y) = \frac{1}{N} \sum_{n=1}^N \left[ -Z_{n,Y_n} + \log\left( \sum_{i=1}^C e^{Z_{n,i}} \right) \right] \]
where \( Z \in \mathbb{R}^{N \times C} \) is the batch of logits, and \( Y = [y_1, y_2, ..., y_N] \) is the batch of true labels.
- Application Scenarios
Image classification (e.g., ImageNet 1000-class tasks), text category classification (e.g., news topic labeling).
nn.BCEWithLogitsLoss
- Core Principle
Used for binary classification (two classes) or multi-label classification (each sample can belong to multiple classes). It integrates nn.Sigmoid (to map logits to probabilities in \( [0,1] \)) and nn.BCELoss (Binary Cross-Entropy Loss) to stabilize training.
- Mathematical Expression
For a single sample with:
Logit (model output for one class): \( z \),
True label: \( y \) (0 or 1, since it's binary per class), the loss is:
\[ \text{BCEWithLogitsLoss}(z, y) = -\big[y \cdot \log(\sigma(z)) + (1-y) \cdot \log(1-\sigma(z))\big] \]
where \(\sigma(z) = \frac{1}{1+e^{-z}}\) is the Sigmoid function (converts logits to probabilities).
For a batch of \( N \) samples and \( K \) classes (multi-label scenario), the average loss is:
\[ \text{BCEWithLogitsLoss}(Z, Y) = \frac{1}{N \cdot K} \sum_{n=1}^N \sum_{k=1}^K \Big[ -Y_{n,k} \cdot \log(\sigma(Z_{n,k})) - (1-Y_{n,k}) \cdot \log(1-\sigma(Z_{n,k})) \Big] \]
where \( Z \in \mathbb{R}^{N \times K} \) is the batch of logits, and \( Y \in \{0,1\}^{N \times K} \) is the batch of true multi-label matrices.
- Application Scenarios
Binary tasks (e.g., spam detection, medical disease diagnosis), multi-label tasks (e.g., image tagging: a photo may contain "cat" and "tree" simultaneously).
nn.MSELoss
- Core Principle
Short for Mean Squared Error Loss, it measures the average of the squared differences between predictions and true values. Primarily used for regression tasks (predicting continuous values) due to its smooth gradient.
- Mathematical Expression
For a single sample with:
Model prediction: \( \hat{y} \),
True target value: \( y \) (a continuous number, e.g., house price, temperature),
the loss is:
\[ \text{MSELoss}(\hat{y}, y) = (\hat{y} - y)^2 \]
For a batch of \( N \) samples, the average loss is:
\[ \text{MSELoss}(\hat{Y}, Y) = \frac{1}{N} \sum_{n=1}^N (\hat{Y}_n - Y_n)^2 \]
where \( \hat{Y} = [\hat{y}_1, \hat{y}_2, ..., \hat{y}_N] \) is the batch of predictions, and \( Y = [y_1, y_2, ..., y_N] \) is the batch of true values.
- Application Scenarios
Continuous value prediction (e.g., house price regression, time-series temperature forecasting, image super-resolution). Note: MSE is sensitive to outliers (large errors are amplified by squaring).
nn.HuberLoss
- Core Principle
A robust regression loss that balances the smoothness of MSE and the outlier resistance of MAE (Mean Absolute Error). It uses a hyperparameter \( \delta > 0 \) to switch between quadratic (MSE-like) and linear (MAE-like) loss:
When the absolute error is small (\( |\hat{y} - y| \leq \delta \)): Uses MSE (smooth gradient for stable training).
When the absolute error is large (\( |\hat{y} - y| > \delta \)): Uses MAE (avoids amplifying outlier errors).
- Mathematical Expression
For a single sample:
\[ \text{HuberLoss}(\hat{y}, y; \delta) = \begin{cases} \frac{1}{2} (\hat{y} - y)^2 & \text{if } |\hat{y} - y| \leq \delta, \\ \delta \cdot |\hat{y} - y| - \frac{1}{2} \delta^2 & \text{if } |\hat{y} - y| > \delta. \end{cases} \]
For a batch of \( N \) samples, the average loss is:
\[ \text{HuberLoss}(\hat{Y}, Y; \delta) = \frac{1}{N} \sum_{n=1}^N \begin{cases} \frac{1}{2} (\hat{Y}_n - Y_n)^2 & \text{if } |\hat{Y}_n - Y_n| \leq \delta, \\ \delta \cdot |\hat{Y}_n - Y_n| - \frac{1}{2} \delta^2 & \text{if } |\hat{Y}_n - Y_n| > \delta. \end{cases} \]
- Application Scenarios
Regression tasks with outliers (e.g., sensor data prediction with noise, autonomous driving speed estimation), where MSE would be distorted by extreme values. The choice of \( \delta \) depends on the data: a larger \( \delta \) makes the loss behave more like MSE, while a smaller \( \delta \) makes it behave more like MAE.
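For concreteness, these four losses map onto the AOMM subtasks as follows; a brief sketch whose mapping matches the training table at the end of this page:
import torch.nn as nn

task_losses = {
    "amp_classification": nn.CrossEntropyLoss(),
    "bioactivity_classification": nn.BCEWithLogitsLoss(),
    "half_life_regression": nn.MSELoss(),
    "hemolysis_regression": nn.MSELoss(),
    "mic_regression": nn.HuberLoss(delta=1.0),
}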
Pretrain
Run pretrain.py to pretrain the AMP-Oriented Multi-task Model (AOMM) with Masked Language Modeling (MLM).
Pretrain Details
- Learning rate: 5e-5
We sample a subset of tokens in each sequence for MLM training:
- 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
- 10% of the time, we replace masked input tokens with a random token
- The rest of the time (10% of the time), we keep the masked input tokens unchanged
def mask_tokens(self, inputs):
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
labels = inputs.clone()
# We sample a few tokens in each sequence for MLM training (with probability `self.mask_prob`)
special_tokens_mask = torch.zeros_like(inputs, dtype=torch.bool)
for token_id in self.special_token_ids:
special_tokens_mask |= (inputs == token_id)
probability_matrix = torch.full(labels.shape, self.mask_prob)
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.prob_replace_mask)).bool() & masked_indices
inputs[indices_replaced] = self.mask_token_id
# 10% of the time, we replace masked input tokens with random word
current_prob = self.prob_replace_rand / (1 - self.prob_replace_mask)
indices_random = torch.bernoulli(torch.full(labels.shape, current_prob)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
for token_id in self.special_token_ids:
indices_random = indices_random & (inputs != token_id)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
The loss function is cross-entropy loss, computed only on the masked positions (labels of -100 are ignored):
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
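As a rough, hypothetical illustration of how the masked batch is consumed (masker is an instance of the class defining mask_tokens above, and the model is assumed to return MLM logits of shape (batch, seq_len, vocab_size); the actual forward call lives in pretrain.py):
inputs, labels = masker.mask_tokens(batch["input_ids"].clone())
mlm_logits = model(input_ids=inputs, attention_mask=batch["attention_mask"])  # assumed output shape (B, L, vocab)
loss = loss_fn(mlm_logits.view(-1, mlm_logits.size(-1)), labels.view(-1))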
Multi-task Training
Run Trainer.py to perform multi-task training. The first 5 layers of the encoder are frozen during this stage.
Multi-task Training Details
- Learning rate: 2e-5
- During task switching, initialize the model with the best model from the previous task
def freeze_encoder_layers(self, num_frozen_layers: int = 5):
    """Freeze the first `num_frozen_layers` layers of the encoder, plus the embeddings."""
    assert 0 < num_frozen_layers <= len(self.model.amp.layers), "Invalid number of frozen layers"
    # freeze the word embeddings
    for param in self.model.amp.word_embeddings.parameters():
        param.requires_grad = False
    # freeze the embedding layer norm
    for param in self.model.amp.emb_layer_norm.parameters():
        param.requires_grad = False
    # freeze the first `num_frozen_layers` encoder layers
    for i, layer in enumerate(self.model.amp.layers):
        if i < num_frozen_layers:
            for param in layer.parameters():
                param.requires_grad = False
As noted above, during task switching we initialize the model with the best checkpoint of the previous task:
def train_all_tasks(self, epochs_per_task=100, start_from=None):
    if start_from:
        start_idx = self.task_order.index(start_from)
        tasks_to_train = self.task_order[start_idx:]
    else:
        tasks_to_train = self.task_order
    for i, task in enumerate(tasks_to_train):
        print(f"\n=== Starting Task {i+1}/{len(tasks_to_train)}: {task} ===")
        if i > 0:
            # index into tasks_to_train so the correct previous checkpoint is loaded even when start_from is set
            prev_task = tasks_to_train[i - 1]
            checkpoint_path = os.path.join(self.save_dir, f"best_{prev_task}_model.pth")
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=True)
            self.model.load_state_dict(checkpoint)
        self.train_task(task, epochs=epochs_per_task)
        print(f"Completed training for {task}\n")
        # release memory between tasks
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
Training Task Order
- amp_classification
- bioactivity_classification
- half_life_regression
- hemolysis_regression
- mic_regression
| Task | Loss Function | Best Validation Loss | Epoch (Early Stop) |
|---|---|---|---|
| amp_classification | CrossEntropyLoss() | 0.1079 | 6 |
| bioactivity_classification | BCEWithLogitsLoss() | 0.2076 | 11 |
| half_life_regression | MSELoss() | 0.0118 | 16 |
| hemolysis_regression | MSELoss() | 0.3841 | 30 |
| mic_regression | HuberLoss(delta=1.0) | 0.0863 | 30 |