Train DL
Setup
Utils
seed_everything
seed_everything (seed=123)
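The full body isn't rendered here, but a `seed_everything` along these lines typically pins every RNG involved in training; a minimal sketch (the exact body is an assumption):

```python
import os, random
import numpy as np
import torch

def seed_everything(seed=123):
    "Pin Python, NumPy, and PyTorch RNGs for reproducible runs (assumed body)."
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```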
def_device
'cpu'
Load Data
# read training data
df = pd.read_parquet('https://github.com/sky1ove/katlas_raw/raw/refs/heads/main/nbs/raw/combine_t5_kd.parquet').reset_index()
# read the data containing kinase info used for splitting
info_df = Data.get_kinase_info().query('pseudo!="1"') # keep non-pseudo kinases
# merge info with training data
info = df[['kinase']].merge(info_df)
info.head()
# splits
splits = get_splits(info, stratified='group')
split0 = splits[0]
# column names of features and targets
feat_col = df.columns[df.columns.str.startswith('T5_')]
target_col = df.columns[~df.columns.isin(feat_col)][1:]
StratifiedKFold(n_splits=5, random_state=123, shuffle=True)
# kinase group in train set: 9
# kinase group in test set: 9
---------------------------
# kinase in train set: 312
---------------------------
# kinase in test set: 78
---------------------------
test set: ['EPHA3' 'FES' 'FLT3' 'FYN' 'EPHB1' 'EPHB3' 'FER' 'EPHB4' 'FLT4' 'FGFR1'
'EPHA5' 'TEK' 'DDR2' 'ZAP70' 'LIMK1' 'ULK3' 'JAK1' 'WEE1' 'TESK1'
'MAP2K3' 'AMPKA2' 'ATM' 'CAMK1D' 'CAMK2D' 'CAMK4' 'CAMKK1' 'CK1D' 'CK1E'
'DYRK2' 'DYRK4' 'HGK' 'IKKE' 'JNK2' 'JNK3' 'KHS1' 'MAPKAPK5' 'MEK2'
'MSK2' 'NDR1' 'NEK6' 'NEK9' 'NIM1' 'NLK' 'OSR1' 'P38A' 'P38B' 'P90RSK'
'PAK1' 'PERK' 'PKCH' 'PKCI' 'PKN1' 'ROCK2' 'RSK2' 'SIK' 'STLK3' 'TAK1'
'TSSK1' 'ALPHAK3' 'BMPR2' 'CDK10' 'CDK13' 'CDK14' 'CDKL5' 'GCN2' 'GRK4'
'IRE1' 'KHS2' 'MASTL' 'MLK4' 'MNK1' 'MRCKA' 'PRPK' 'QSK' 'SMMLCK' 'SSTK'
'ULK2' 'VRK1']
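The `StratifiedKFold(n_splits=5, random_state=123, shuffle=True)` printout above hints at what `get_splits` does with `stratified='group'`; a rough sketch of the idea (`get_splits_sketch` is a hypothetical stand-in, not the library function):

```python
from sklearn.model_selection import StratifiedKFold

def get_splits_sketch(info, stratified='group', n_splits=5, seed=123):
    # (train_idx, test_idx) pairs that keep the kinase-group proportions
    # roughly equal across folds
    skf = StratifiedKFold(n_splits=n_splits, random_state=seed, shuffle=True)
    return list(skf.split(info, info[stratified]))
```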
Dataset
GeneralDataset
GeneralDataset (df, feat_col, target_col=None)
A general dataset that can be applied to any dataframe
| | Type | Default | Details |
|---|---|---|---|
| df | | | a dataframe of values |
| feat_col | | | feature columns |
| target_col | NoneType | None | returns a test set (features only) for prediction if target_col is None |
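A minimal sketch of what such a dataset could look like (an assumed implementation, not necessarily the library's):

```python
import torch
from torch.utils.data import Dataset

class GeneralDatasetSketch(Dataset):
    "Serve rows of any dataframe as float tensors."
    def __init__(self, df, feat_col, target_col=None):
        self.x = torch.tensor(df[feat_col].values, dtype=torch.float32)
        # no target_col -> behave as a test set and return features only
        self.y = None if target_col is None else torch.tensor(
            df[target_col].values, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i] if self.y is None else (self.x[i], self.y[i])
```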
# dataset
ds = GeneralDataset(df, feat_col, target_col)
len(ds)
390
dl = DataLoader(ds, batch_size=64, shuffle=True)
get_sampler
get_sampler (info, col)
For imbalanced data, get higher weights for less-represented samples
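A standard way to do this is a `WeightedRandomSampler` with inverse class-frequency weights; a sketch under that assumption (`get_sampler_sketch` is hypothetical):

```python
import torch
from torch.utils.data import WeightedRandomSampler

def get_sampler_sketch(info, col):
    # weight each sample by the inverse frequency of its class in `col`,
    # so under-represented classes are drawn as often as common ones
    counts = info[col].value_counts()
    weights = (1.0 / info[col].map(counts)).values
    return WeightedRandomSampler(torch.as_tensor(weights, dtype=torch.double),
                                 num_samples=len(weights), replacement=True)
```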
sampler = get_sampler(info, 'subfamily')
# dataloader
dl = DataLoader(ds, batch_size=64, sampler=sampler)
xb, yb = next(iter(dl))
xb.shape,yb.shape
(torch.Size([64, 1024]), torch.Size([64, 210]))
Models
MLP
MLP_1
MLP_1 (num_features, num_targets, hidden_units=[512, 218], dp=0.2)
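From the signature, MLP_1 is a two-hidden-layer perceptron with dropout; a plausible sketch (the exact block order is an assumption):

```python
import torch.nn as nn

class MLPSketch(nn.Module):
    "Linear -> BatchNorm -> ReLU -> Dropout blocks, then a linear head."
    def __init__(self, num_features, num_targets, hidden_units=[512, 218], dp=0.2):
        super().__init__()
        layers, ni = [], num_features
        for nf in hidden_units:
            layers += [nn.Linear(ni, nf), nn.BatchNorm1d(nf),
                       nn.ReLU(), nn.Dropout(dp)]
            ni = nf
        layers.append(nn.Linear(ni, num_targets))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
```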
n_feature = len(feat_col)
n_target = len(target_col)
model = MLP_1(n_feature, n_target)
model(xb)
tensor([[-0.1115, -0.3755, -0.3818, ..., -0.1483, -0.0387, -0.1111],
[ 0.8555, 0.9352, -0.9642, ..., -0.4723, 0.7757, -0.0121],
[ 0.3422, 0.3537, -0.1441, ..., 0.5467, -0.4535, 0.2103],
...,
[-0.4287, 0.6751, 0.1797, ..., 0.0192, 0.0692, -0.0573],
[-0.0206, -0.1953, 0.7445, ..., -0.2206, -0.1188, 0.4579],
[ 0.2342, -0.0243, 0.4630, ..., 0.8393, 0.5747, -0.6881]],
grad_fn=<AddmmBackward0>)
CNN1D
Version 1
CNN1D_1
CNN1D_1 (num_features, num_targets)
Same as `nn.Module`, but no need for subclasses to call `super().__init__()`
| | Details |
|---|---|
| num_features | not used; kept only so the constructor matches the other models |
| num_targets | |
model = CNN1D_1(n_feature, n_target)
model(xb)
tensor([[ 0.0193, 0.0690, 0.0138, ..., -0.0428, -0.0026, 0.0840],
[ 0.0203, 0.0693, 0.0136, ..., -0.0422, -0.0023, 0.0846],
[ 0.0198, 0.0703, 0.0148, ..., -0.0424, -0.0029, 0.0839],
...,
[ 0.0197, 0.0694, 0.0147, ..., -0.0429, -0.0019, 0.0841],
[ 0.0193, 0.0687, 0.0146, ..., -0.0429, -0.0017, 0.0843],
[ 0.0191, 0.0692, 0.0148, ..., -0.0425, -0.0028, 0.0834]],
grad_fn=<AddmmBackward0>)
Version 2
init_weights
init_weights (m, leaky=0.0)
Initialize any Conv layer with Kaiming normal initialization.
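The docstring pins down the behavior; a sketch consistent with it (`init_weights_sketch` is a stand-in name):

```python
import torch.nn as nn

def init_weights_sketch(m, leaky=0.0):
    # Kaiming-normal init for any Conv layer; `leaky` is the negative slope
    # of the assumed (leaky) ReLU that follows
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_normal_(m.weight, a=leaky)
```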
lin_wn
lin_wn (ni, nf, dp=0.1, act=<class 'torch.nn.modules.activation.SiLU'>)
Weight-normalized linear block.
conv_wn
conv_wn (ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=<class 'torch.nn.modules.activation.ReLU'>)
Weight-normalized convolutional block.
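Both helpers appear to wrap a layer in weight normalization and pair it with dropout and an activation; a sketch of that pattern (the composition order is an assumption; note the `torch.nn.utils.weight_norm` deprecation warning in the outputs below):

```python
import torch.nn as nn
from torch.nn.utils import weight_norm

def lin_wn_sketch(ni, nf, dp=0.1, act=nn.SiLU):
    # weight-normalized linear block: Dropout -> Linear(WN) -> activation
    return nn.Sequential(nn.Dropout(dp), weight_norm(nn.Linear(ni, nf)), act())

def conv_wn_sketch(ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=nn.ReLU):
    # weight-normalized 1D-conv block with the same pattern
    return nn.Sequential(nn.Dropout(dp),
                         weight_norm(nn.Conv1d(ni, nf, ks, stride, padding)),
                         act())
```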
CNN1D_2
CNN1D_2 (ni, nf, amp_scale=16)
*Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:*

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

*Submodules assigned in this way will be registered, and their parameters will also be converted when you call `to()`, etc. Note that, as per the example above, an `__init__()` call to the parent class must be made before assignment on the child. The boolean attribute `training` indicates whether the module is in training or evaluation mode.*
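CNN1D_2 itself has no rendered docstring (the text above is inherited from `nn.Module`), so its architecture isn't documented here. One plausible reading of `(ni, nf, amp_scale=16)` is: project the flat feature vector, reshape it into a multi-channel 1D signal, and convolve; a hedged sketch of that idea only, not the actual model:

```python
import torch
import torch.nn as nn

class CNN1DSketch(nn.Module):
    "Hypothetical: treat tabular features as a 1D signal for Conv1d layers."
    def __init__(self, ni, nf, amp_scale=16):
        super().__init__()
        self.ch, self.seq = amp_scale, ni // amp_scale  # e.g. 1024 -> (16, 64)
        self.expand = nn.Linear(ni, self.ch * self.seq)
        self.conv = nn.Sequential(nn.Conv1d(self.ch, 32, 3, padding=1),
                                  nn.ReLU(), nn.AdaptiveAvgPool1d(8))
        self.head = nn.Linear(32 * 8, nf)

    def forward(self, x):
        x = self.expand(x).view(x.size(0), self.ch, self.seq)
        return self.head(self.conv(x).flatten(1))
```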
model = CNN1D_2(n_feature, n_target).apply(init_weights)
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
model(xb)
tensor([[-0.5740, 0.0151, -0.0819, ..., 0.2636, 0.3405, -0.1404],
[-0.6800, 0.5530, -0.0958, ..., -0.3752, -0.6124, 0.7171],
[ 0.4427, -0.3204, -0.3243, ..., -0.2290, 0.1070, 0.1504],
...,
[-0.3660, -0.2667, -0.6036, ..., -0.3130, 0.5462, -0.0055],
[ 0.4511, 0.6824, 0.8659, ..., -0.0171, 0.2362, -0.3475],
[-0.0746, -0.1699, 0.6895, ..., 1.1522, -0.3472, 0.6422]],
grad_fn=<AddmmBackward0>)
DL Trainer
train_dl
train_dl (df, feat_col, target_col, split, model_func, n_epoch=4, bs=32, lr=0.01, loss=<function mse>, save=None, sampler=None, lr_find=False)
A DL trainer.
| | Type | Default | Details |
|---|---|---|---|
| df | | | |
| feat_col | | | |
| target_col | | | |
| split | | | tuple of numpy arrays (train and valid indices) |
| model_func | | | function that returns a pytorch model |
| n_epoch | int | 4 | number of epochs |
| bs | int | 32 | batch size |
| lr | float | 0.01 | ignored if lr_find is True |
| loss | function | mse | loss function |
| save | NoneType | None | saves the model to models/{save}.pth |
| sampler | NoneType | None | |
| lr_find | bool | False | if True, use the lr suggested by lr_find |
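In outline, the trainer slices the dataframe with the split indices, builds train/valid dataloaders, and fits; a stripped-down plain-PyTorch sketch (the real train_dl also logs Pearson/Spearman and supports sampler, lr_find, and save; this sketch reuses GeneralDataset from above):

```python
import torch
from torch.utils.data import DataLoader

def train_dl_sketch(df, feat_col, target_col, split, model_func,
                    n_epoch=4, bs=32, lr=0.01):
    train_idx, valid_idx = split
    train_loader = DataLoader(GeneralDataset(df.loc[train_idx], feat_col, target_col),
                              batch_size=bs, shuffle=True)
    valid_loader = DataLoader(GeneralDataset(df.loc[valid_idx], feat_col, target_col),
                              batch_size=bs)
    model = model_func()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(n_epoch):
        model.train()
        for xb, yb in train_loader:
            opt.zero_grad()
            torch.nn.functional.mse_loss(model(xb), yb).backward()
            opt.step()
    # return validation targets and predictions, as train_dl does
    model.eval()
    with torch.no_grad():
        pred = torch.cat([model(xb) for xb, _ in valid_loader]).numpy()
    return df[target_col].loc[valid_idx], pred
```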
def get_model():
    return CNN1D_2(n_feature, n_target)
target, pred = train_dl(df,
                        feat_col,
                        target_col,
                        split0,
                        get_model,
                        n_epoch=1,
                        lr=1e-2,
                        save='test')
lr in training is 0.01
epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
---|---|---|---|---|---|
0 | 1.937994 | 1.231321 | 0.106572 | 0.064782 | 00:01 |
score_each(target,pred)
overall MSE: 1.2313
Average Pearson: 0.1580
(1.2313209,
0.1579942852920301,
Pearson
3 -0.037969
8 -0.045367
10 -0.057115
19 -0.044484
24 -0.059326
.. ...
359 0.247093
361 -0.147023
366 0.107366
367 -0.004609
373 0.260843
[78 rows x 1 columns])
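score_each has no doc entry above, but its outputs pin down the contract: it prints the overall MSE and average Pearson, and returns (mse, pearson_avg, corr), with one Pearson value per row of the target. A sketch consistent with that (assumed implementation):

```python
import numpy as np
import pandas as pd

def score_each_sketch(target, pred):
    # overall MSE across all cells, plus a per-row (per-kinase) Pearson
    t, p = np.asarray(target, dtype=float), np.asarray(pred, dtype=float)
    mse = ((t - p) ** 2).mean()
    pearson = [np.corrcoef(t[i], p[i])[0, 1] for i in range(len(t))]
    corr = pd.DataFrame({'Pearson': pearson}, index=target.index)
    print(f'overall MSE: {mse:.4f}')
    print(f'Average Pearson: {corr.Pearson.mean():.4f}')
    return mse, corr.Pearson.mean(), corr
```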
DL CV
train_dl_cv
train_dl_cv (df, feat_col, target_col, splits, model_func, save:str=None, n_epoch=4, bs=32, lr=0.01, loss=<function mse>, sampler=None, lr_find=False)
| | Type | Default | Details |
|---|---|---|---|
| df | | | |
| feat_col | | | |
| target_col | | | |
| splits | | | list of (train_idx, valid_idx) tuples |
| model_func | | | function that returns a pytorch model, e.g. lambda: MLP_1(num_feat, num_target) |
| save | str | None | saves each fold's model to models/{save}.pth |
| n_epoch | int | 4 | number of epochs |
| bs | int | 32 | batch size |
| lr | float | 0.01 | ignored if lr_find is True |
| loss | function | mse | loss function |
| sampler | NoneType | None | |
| lr_find | bool | False | if True, use the lr suggested by lr_find |
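Cross-validation is essentially train_dl run once per fold, with each fold's validation predictions stitched into one out-of-fold (OOF) matrix; a sketch of that loop (assumed, reusing train_dl and score_each from this page):

```python
import numpy as np
import pandas as pd

def train_dl_cv_sketch(df, feat_col, target_col, splits, model_func, **kwargs):
    oof = np.zeros((len(df), len(target_col)))
    rows = []
    for i, split in enumerate(splits):
        print(f'------fold{i}------')
        target, pred = train_dl(df, feat_col, target_col, split,
                                model_func, **kwargs)
        oof[split[1]] = pred            # each row predicted by its valid fold
        mse, pearson_avg, _ = score_each(target, pred)
        rows.append({'fold': i, 'mse': mse, 'pearson_avg': pearson_avg})
    return oof, pd.DataFrame(rows)
```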
def get_model():
    return CNN1D_2(n_feature, n_target)
oof, metrics = train_dl_cv(df, feat_col, target_col, splits, get_model, n_epoch=1, lr=3e-3)
------fold0------
lr in training is 0.003
epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
---|---|---|---|---|---|
0 | 1.157886 | 0.985997 | 0.123256 | 0.076049 | 00:01 |
overall MSE: 0.9860
Average Pearson: 0.2104
------fold1------
lr in training is 0.003
epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
---|---|---|---|---|---|
0 | 1.194019 | 0.984086 | 0.130521 | 0.092631 | 00:01 |
overall MSE: 0.9841
Average Pearson: 0.1521
------fold2------
lr in training is 0.003
epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
---|---|---|---|---|---|
0 | 1.154190 | 0.988698 | 0.114616 | 0.064543 | 00:01 |
overall MSE: 0.9887
Average Pearson: 0.2677
------fold3------
lr in training is 0.003
epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
---|---|---|---|---|---|
0 | 1.177719 | 0.975760 | 0.156240 | 0.125270 | 00:01 |
overall MSE: 0.9758
Average Pearson: 0.1862
------fold4------
lr in training is 0.003
epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
---|---|---|---|---|---|
0 | 1.170352 | 0.983774 | 0.135724 | 0.102884 | 00:01 |
overall MSE: 0.9838
Average Pearson: 0.2547
metrics
| | fold | mse | pearson_avg |
|---|---|---|---|
| 0 | 0 | 0.985997 | 0.210434 |
| 1 | 1 | 0.984086 | 0.152108 |
| 2 | 2 | 0.988698 | 0.267718 |
| 3 | 3 | 0.975760 | 0.186243 |
| 4 | 4 | 0.983774 | 0.254678 |
metrics.pearson_avg.mean()
0.2142360341441809
target = df[target_col]
_, _, corr = score_each(target, oof)
overall MSE: 0.9837
Average Pearson: 0.2142
corr
| | Pearson |
|---|---|
| 0 | -0.130073 |
| 1 | -0.210985 |
| 2 | -0.251176 |
| 3 | -0.163586 |
| 4 | -0.054230 |
| ... | ... |
| 385 | 0.186961 |
| 386 | 0.336823 |
| 387 | -0.029309 |
| 388 | 0.029152 |
| 389 | 0.440186 |

390 rows × 1 columns
DL Predict
predict_dl
predict_dl (df, feat_col, target_col, model, model_pth)
Predict a dataframe given a deep learning model.

| | Details |
|---|---|
| df | |
| feat_col | |
| target_col | |
| model | model architecture |
| model_pth | name only, without the .pth extension |
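A sketch of the prediction path implied by the table (assumed: weights live at models/{model_pth}.pth, and the output is wrapped in a dataframe with the target columns; reuses GeneralDataset from above):

```python
import torch
import pandas as pd
from torch.utils.data import DataLoader

def predict_dl_sketch(df, feat_col, target_col, model, model_pth):
    # model_pth is the bare name, hence the added .pth
    model.load_state_dict(torch.load(f'models/{model_pth}.pth'))
    model.eval()
    dl = DataLoader(GeneralDataset(df, feat_col), batch_size=64)  # test mode
    with torch.no_grad():
        out = torch.cat([model(xb) for xb in dl])
    return pd.DataFrame(out.numpy(), columns=target_col, index=df.index)
```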
test = df.loc[split0[1]]
pred = predict_dl(test.head(3),
                  feat_col,
                  target_col,
                  model,
                  'test')
pred
| | -5P | -5G | -5A | -5C | -5S | -5T | -5V | -5I | -5L | -5M | ... | 4Q | 4N | 4D | 4E | 4s | 4t | 4y | 0s | 0t | 0y |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 3 | -0.355736 | -0.371514 | -0.814290 | -0.573196 | 1.173128 | 0.863355 | -0.738675 | 0.366732 | 0.702826 | 1.137061 | ... | -0.286256 | -0.485139 | 0.012519 | 0.306050 | -0.215432 | 0.959879 | 0.697288 | 0.805923 | 0.199813 | -0.397550 |
| 8 | -0.322134 | -0.418208 | -0.835480 | -0.620543 | 1.237427 | 0.920613 | -0.783180 | 0.385320 | 0.750605 | 1.191602 | ... | -0.300988 | -0.502365 | -0.041983 | 0.355885 | -0.245259 | 1.002593 | 0.699817 | 0.878045 | 0.185331 | -0.444951 |
| 10 | -0.349195 | -0.373037 | -0.816447 | -0.580855 | 1.182283 | 0.871513 | -0.739842 | 0.368429 | 0.705171 | 1.146074 | ... | -0.286629 | -0.483317 | 0.007166 | 0.306860 | -0.217535 | 0.966914 | 0.695795 | 0.811916 | 0.201817 | -0.403425 |

3 rows × 210 columns
_, _, corr = score_each(test[target_col].head(3), pred)
overall MSE: 1.4146
Average Pearson: 0.0468