Train DL

A collection of deep learning tools built on fastai.

Setup

Utils


source

seed_everything

 seed_everything (seed=123)
seed_everything()
def_device
'cpu'
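seed_everything and def_device are not shown here; a minimal sketch of what such utilities typically look like (the exact implementation may differ):

import random
import numpy as np
import torch

def seed_everything(seed=123):
    # seed python, numpy, and torch (CPU and all GPUs) for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def_device = 'cuda' if torch.cuda.is_available() else 'cpu'  # assumed definition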

Load Data

# read training data
df = pd.read_parquet('https://github.com/sky1ove/katlas_raw/raw/refs/heads/main/nbs/raw/combine_t5_kd.parquet').reset_index()

# read data that contains the info used for splitting
info_df = Data.get_kinase_info().query('pseudo!="1"') # get non-pseudo kinases

# merge info with training data
info = df[['kinase']].merge(info_df)
info.head()

# splits
splits = get_splits(info,stratified='group')
split0 = splits[0]


# column names of features and targets
feat_col = df.columns[df.columns.str.startswith('T5_')]
target_col = df.columns[~df.columns.isin(feat_col)][1:]
StratifiedKFold(n_splits=5, random_state=123, shuffle=True)
# kinase group in train set: 9
# kinase group in test set: 9
---------------------------
# kinase in train set: 312
---------------------------
# kinase in test set: 78
---------------------------
test set: ['EPHA3' 'FES' 'FLT3' 'FYN' 'EPHB1' 'EPHB3' 'FER' 'EPHB4' 'FLT4' 'FGFR1'
 'EPHA5' 'TEK' 'DDR2' 'ZAP70' 'LIMK1' 'ULK3' 'JAK1' 'WEE1' 'TESK1'
 'MAP2K3' 'AMPKA2' 'ATM' 'CAMK1D' 'CAMK2D' 'CAMK4' 'CAMKK1' 'CK1D' 'CK1E'
 'DYRK2' 'DYRK4' 'HGK' 'IKKE' 'JNK2' 'JNK3' 'KHS1' 'MAPKAPK5' 'MEK2'
 'MSK2' 'NDR1' 'NEK6' 'NEK9' 'NIM1' 'NLK' 'OSR1' 'P38A' 'P38B' 'P90RSK'
 'PAK1' 'PERK' 'PKCH' 'PKCI' 'PKN1' 'ROCK2' 'RSK2' 'SIK' 'STLK3' 'TAK1'
 'TSSK1' 'ALPHAK3' 'BMPR2' 'CDK10' 'CDK13' 'CDK14' 'CDKL5' 'GCN2' 'GRK4'
 'IRE1' 'KHS2' 'MASTL' 'MLK4' 'MNK1' 'MRCKA' 'PRPK' 'QSK' 'SMMLCK' 'SSTK'
 'ULK2' 'VRK1']
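The log above shows that get_splits stratifies a 5-fold split on the chosen column and reports group/kinase counts per set. A minimal sketch, assuming it wraps sklearn's StratifiedKFold (the real helper also prints the summary above):

from sklearn.model_selection import StratifiedKFold

def get_splits(info, stratified='group', n_splits=5, seed=123):
    # returns a list of (train_idx, test_idx) index arrays, stratified on `stratified`
    skf = StratifiedKFold(n_splits=n_splits, random_state=seed, shuffle=True)
    return list(skf.split(info, info[stratified]))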

Dataset


source

GeneralDataset

 GeneralDataset (df, feat_col, target_col=None)

A general dataset that can be applied to any dataframe.

Type Default Details
df a dataframe of values
feat_col feature columns
target_col NoneType None will return a test set for prediction if target_col is None
# dataset
ds = GeneralDataset(df,feat_col,target_col)
len(ds)
390
dl = DataLoader(ds, batch_size=64, shuffle=True)
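A minimal sketch of a dataset with this behavior (an assumption, not the exact source): features are returned as float tensors, and targets only when target_col is given:

import torch
from torch.utils.data import Dataset

class GeneralDataset(Dataset):
    def __init__(self, df, feat_col, target_col=None):
        self.x = torch.tensor(df[feat_col].values, dtype=torch.float32)
        # without target_col, only features are returned (test/prediction mode)
        self.y = None if target_col is None else torch.tensor(df[target_col].values, dtype=torch.float32)
    def __len__(self): return len(self.x)
    def __getitem__(self, i):
        return self.x[i] if self.y is None else (self.x[i], self.y[i])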

source

get_sampler

 get_sampler (info, col)

For imbalanced data, assign higher sampling weights to less-represented samples.

sampler = get_sampler(info,'subfamily')
# dataloader
dl = DataLoader(ds, batch_size=64, sampler=sampler)
xb,yb = next(iter(dl))

xb.shape,yb.shape
(torch.Size([64, 1024]), torch.Size([64, 210]))
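Such a sampler is commonly built with torch's WeightedRandomSampler, weighting each sample by the inverse frequency of its class; a hedged sketch (assumed implementation):

import torch
from torch.utils.data import WeightedRandomSampler

def get_sampler(info, col):
    # weight each sample inversely to the frequency of its class in `col`
    counts = info[col].value_counts()
    weights = 1.0 / info[col].map(counts).values
    return WeightedRandomSampler(torch.as_tensor(weights, dtype=torch.double),
                                 num_samples=len(weights), replacement=True)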

Models

MLP


source

MLP_1

 MLP_1 (num_features, num_targets, hidden_units=[512, 218], dp=0.2)
n_feature = len(feat_col)
n_target = len(target_col)
model = MLP_1(n_feature, n_target)
model(xb)
tensor([[-0.1115, -0.3755, -0.3818,  ..., -0.1483, -0.0387, -0.1111],
        [ 0.8555,  0.9352, -0.9642,  ..., -0.4723,  0.7757, -0.0121],
        [ 0.3422,  0.3537, -0.1441,  ...,  0.5467, -0.4535,  0.2103],
        ...,
        [-0.4287,  0.6751,  0.1797,  ...,  0.0192,  0.0692, -0.0573],
        [-0.0206, -0.1953,  0.7445,  ..., -0.2206, -0.1188,  0.4579],
        [ 0.2342, -0.0243,  0.4630,  ...,  0.8393,  0.5747, -0.6881]],
       grad_fn=<AddmmBackward0>)
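From its signature, MLP_1 is a plain multilayer perceptron with dropout; a rough sketch of that architecture (the exact layer ordering and any normalization are assumptions):

import torch.nn as nn

class MLP_1(nn.Module):
    def __init__(self, num_features, num_targets, hidden_units=[512, 218], dp=0.2):
        super().__init__()
        layers, ni = [], num_features
        for nf in hidden_units:
            # linear -> relu -> dropout per hidden layer (assumed ordering)
            layers += [nn.Linear(ni, nf), nn.ReLU(), nn.Dropout(dp)]
            ni = nf
        layers.append(nn.Linear(ni, num_targets))
        self.net = nn.Sequential(*layers)
    def forward(self, x): return self.net(x)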

CNN1D

Version 1


source

CNN1D_1

 CNN1D_1 (num_features, num_targets)

Same as nn.Module, but no need for subclasses to call super().__init__

Details
num_features not used internally; kept only so all models share the same interface
num_targets
model = CNN1D_1(n_feature, n_target)
model(xb)
tensor([[ 0.0193,  0.0690,  0.0138,  ..., -0.0428, -0.0026,  0.0840],
        [ 0.0203,  0.0693,  0.0136,  ..., -0.0422, -0.0023,  0.0846],
        [ 0.0198,  0.0703,  0.0148,  ..., -0.0424, -0.0029,  0.0839],
        ...,
        [ 0.0197,  0.0694,  0.0147,  ..., -0.0429, -0.0019,  0.0841],
        [ 0.0193,  0.0687,  0.0146,  ..., -0.0429, -0.0017,  0.0843],
        [ 0.0191,  0.0692,  0.0148,  ..., -0.0425, -0.0028,  0.0834]],
       grad_fn=<AddmmBackward0>)
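Since num_features is unused, CNN1D_1 presumably treats the flat embedding as a single-channel sequence for 1D convolution; a rough sketch of that idea (channel sizes and pooling are assumptions):

import torch.nn as nn

class CNN1D_1(nn.Module):
    def __init__(self, num_features, num_targets):  # num_features kept for a uniform interface
        super().__init__()
        self.conv = nn.Sequential(nn.Conv1d(1, 16, 5, padding=2), nn.ReLU(),
                                  nn.AdaptiveAvgPool1d(64))
        self.head = nn.Linear(16 * 64, num_targets)
    def forward(self, x):
        x = self.conv(x.unsqueeze(1))  # (bs, n) -> (bs, 1, n) -> (bs, 16, 64)
        return self.head(x.flatten(1))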

Version 2


source

init_weights

 init_weights (m, leaky=0.0)

Initialize any Conv layer with Kaiming normal initialization.
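A sketch consistent with this description, passing the leaky slope through to the initializer (assumed implementation):

import torch.nn as nn

def init_weights(m, leaky=0.0):
    # apply Kaiming normal init to the weights of any conv layer
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_normal_(m.weight, a=leaky)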


source

lin_wn

 lin_wn (ni, nf, dp=0.1, act=<class 'torch.nn.modules.activation.SiLU'>)

Weight-normalized linear block.


source

conv_wn

 conv_wn (ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=<class
          'torch.nn.modules.activation.ReLU'>)

Weight-normalized conv block.
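Both helpers presumably wrap their layer in weight normalization and pair it with dropout and an activation; a hedged sketch of the pair (block ordering is an assumption):

import torch.nn as nn
from torch.nn.utils import weight_norm

def lin_wn(ni, nf, dp=0.1, act=nn.SiLU):
    # weight-normalized linear followed by dropout and activation
    return nn.Sequential(weight_norm(nn.Linear(ni, nf)), nn.Dropout(dp), act())

def conv_wn(ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=nn.ReLU):
    # weight-normalized 1d conv followed by dropout and activation
    return nn.Sequential(weight_norm(nn.Conv1d(ni, nf, ks, stride, padding)),
                         nn.Dropout(dp), act())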


source

CNN1D_2

 CNN1D_2 (ni, nf, amp_scale=16)

*Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and their parameters will also be converted when you call to(), etc.

Note: as per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.*

model = CNN1D_2(n_feature,n_target).apply(init_weights)
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
  warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
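The warning comes from the deprecated torch.nn.utils.weight_norm; newer PyTorch releases expose the same functionality under torch.nn.utils.parametrizations, e.g.:

import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

layer = weight_norm(nn.Linear(1024, 512))  # parametrized replacement for the deprecated API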
model(xb)
tensor([[-0.5740,  0.0151, -0.0819,  ...,  0.2636,  0.3405, -0.1404],
        [-0.6800,  0.5530, -0.0958,  ..., -0.3752, -0.6124,  0.7171],
        [ 0.4427, -0.3204, -0.3243,  ..., -0.2290,  0.1070,  0.1504],
        ...,
        [-0.3660, -0.2667, -0.6036,  ..., -0.3130,  0.5462, -0.0055],
        [ 0.4511,  0.6824,  0.8659,  ..., -0.0171,  0.2362, -0.3475],
        [-0.0746, -0.1699,  0.6895,  ...,  1.1522, -0.3472,  0.6422]],
       grad_fn=<AddmmBackward0>)

DL Trainer


source

train_dl

 train_dl (df, feat_col, target_col, split, model_func, n_epoch=4, bs=32,
           lr=0.01, loss=<function mse>, save=None, sampler=None,
           lr_find=False)

A DL trainer.

Type Default Details
df
feat_col
target_col
split tuple of numpy arrays with train/valid indices
model_func function that returns a pytorch model
n_epoch int 4 number of epochs
bs int 32 batch size
lr float 0.01 ignored if lr_find is True
loss function mse loss function
save NoneType None if set, saves the model to models/{save}.pth
sampler NoneType None optional sampler for the train dataloader
lr_find bool False if True, uses the lr suggested by lr_find
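The per-epoch pearsonr/spearmanr columns in the log below suggest a fastai Learner under the hood; a rough sketch of such a trainer under that assumption (hypothetical name, not the exact source):

import torch.nn.functional as F
from fastai.data.core import DataLoaders
from fastai.learner import Learner
from fastai.callback.schedule import fit_one_cycle  # importing patches Learner.fit_one_cycle

def train_dl_sketch(df, feat_col, target_col, split, model_func,
                    n_epoch=4, bs=32, lr=1e-2, loss=F.mse_loss):
    # build train/valid datasets from the split index arrays
    train_ds = GeneralDataset(df.iloc[split[0]], feat_col, target_col)
    valid_ds = GeneralDataset(df.iloc[split[1]], feat_col, target_col)
    dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=bs)
    learn = Learner(dls, model_func(), loss_func=loss)
    learn.fit_one_cycle(n_epoch, lr)
    pred, target = learn.get_preds()  # predictions and targets on the validation fold
    return target, pred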
def get_model():
    return CNN1D_2(n_feature, n_target)
target, pred = train_dl(df, 
                        feat_col, 
                        target_col,
                        split0, 
                        get_model,
                        n_epoch=1,
                        lr = 1e-2,
                        save = 'test')
lr in training is 0.01
epoch train_loss valid_loss pearsonr spearmanr time
0 1.937994 1.231321 0.106572 0.064782 00:01
score_each(target,pred)
overall MSE: 1.2313
Average Pearson: 0.1580 
(1.2313209,
 0.1579942852920301,
       Pearson
 3   -0.037969
 8   -0.045367
 10  -0.057115
 19  -0.044484
 24  -0.059326
 ..        ...
 359  0.247093
 361 -0.147023
 366  0.107366
 367 -0.004609
 373  0.260843
 
 [78 rows x 1 columns])

DL CV


source

train_dl_cv

 train_dl_cv (df, feat_col, target_col, splits, model_func, save:str=None,
              n_epoch=4, bs=32, lr=0.01, loss=<function mse>,
              sampler=None, lr_find=False)
Type Default Details
df
feat_col
target_col
splits list of (train_idx, test_idx) tuples
model_func function that returns a pytorch model, e.g. lambda: MLP_1(num_feat, num_target)
save NoneType None if set, saves the model to models/{save}.pth
n_epoch int 4 number of epochs
bs int 32 batch size
lr float 0.01 ignored if lr_find is True
loss function mse loss function
sampler NoneType None optional sampler for the train dataloader
lr_find bool False if True, uses the lr suggested by lr_find
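Cross-validation is then just a loop over the folds that collects out-of-fold predictions and per-fold metrics; a hedged sketch (hypothetical name, simplified to MSE only):

import numpy as np
import pandas as pd

def train_dl_cv_sketch(df, feat_col, target_col, splits, model_func, **kwargs):
    oof = np.zeros((len(df), len(target_col)))
    rows = []
    for fold, split in enumerate(splits):
        print(f'------fold{fold}------')
        target, pred = train_dl(df, feat_col, target_col, split, model_func, **kwargs)
        oof[split[1]] = pred  # fill the validation rows of this fold
        rows.append({'fold': fold, 'mse': float(((pred - target) ** 2).mean())})
    return oof, pd.DataFrame(rows)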
def get_model():
    return CNN1D_2(n_feature, n_target)
oof,metrics = train_dl_cv(df,feat_col,target_col,splits,get_model,n_epoch=1,lr=3e-3)
------fold0------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.157886 0.985997 0.123256 0.076049 00:01
overall MSE: 0.9860
Average Pearson: 0.2104 
------fold1------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.194019 0.984086 0.130521 0.092631 00:01
overall MSE: 0.9841
Average Pearson: 0.1521 
------fold2------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.154190 0.988698 0.114616 0.064543 00:01
overall MSE: 0.9887
Average Pearson: 0.2677 
------fold3------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.177719 0.975760 0.156240 0.125270 00:01
overall MSE: 0.9758
Average Pearson: 0.1862 
------fold4------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.170352 0.983774 0.135724 0.102884 00:01
overall MSE: 0.9838
Average Pearson: 0.2547 
metrics
fold mse pearson_avg
0 0 0.985997 0.210434
1 1 0.984086 0.152108
2 2 0.988698 0.267718
3 3 0.975760 0.186243
4 4 0.983774 0.254678
metrics.pearson_avg.mean()
0.2142360341441809
target = df[target_col]
_,_,corr = score_each(target,oof)
overall MSE: 0.9837
Average Pearson: 0.2142 
corr
Pearson
0 -0.130073
1 -0.210985
2 -0.251176
3 -0.163586
4 -0.054230
... ...
385 0.186961
386 0.336823
387 -0.029309
388 0.029152
389 0.440186

390 rows × 1 columns

DL Predict


source

predict_dl

 predict_dl (df, feat_col, target_col, model, model_pth)

Predict on a dataframe, given a deep learning model.

Details
df
feat_col
target_col
model model architecture
model_pth model name only, without the .pth extension
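A sketch of the prediction helper under the behavior documented above, assuming the checkpoint at models/{model_pth}.pth is a plain state_dict (hypothetical name):

import pandas as pd
import torch

def predict_dl_sketch(df, feat_col, target_col, model, model_pth):
    model.load_state_dict(torch.load(f'models/{model_pth}.pth'))
    model.eval()  # disable dropout etc. for inference
    x = torch.tensor(df[feat_col].values, dtype=torch.float32)
    with torch.no_grad():
        pred = model(x)
    # return predictions with the target column names and the original index
    return pd.DataFrame(pred.numpy(), columns=target_col, index=df.index)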
test = df.loc[split0[1]]
pred = predict_dl(test.head(3),
                  feat_col,
                  target_col, 
                  model,'test')
pred
-5P -5G -5A -5C -5S -5T -5V -5I -5L -5M ... 4Q 4N 4D 4E 4s 4t 4y 0s 0t 0y
3 -0.355736 -0.371514 -0.814290 -0.573196 1.173128 0.863355 -0.738675 0.366732 0.702826 1.137061 ... -0.286256 -0.485139 0.012519 0.306050 -0.215432 0.959879 0.697288 0.805923 0.199813 -0.397550
8 -0.322134 -0.418208 -0.835480 -0.620543 1.237427 0.920613 -0.783180 0.385320 0.750605 1.191602 ... -0.300988 -0.502365 -0.041983 0.355885 -0.245259 1.002593 0.699817 0.878045 0.185331 -0.444951
10 -0.349195 -0.373037 -0.816447 -0.580855 1.182283 0.871513 -0.739842 0.368429 0.705171 1.146074 ... -0.286629 -0.483317 0.007166 0.306860 -0.217535 0.966914 0.695795 0.811916 0.201817 -0.403425

3 rows × 210 columns

_,_,corr = score_each(test[target_col].head(3),pred)
overall MSE: 1.4146
Average Pearson: 0.0468