Train DL

A collection of deep learning tools built on fastai

Setup

Utils


source

seed_everything

 seed_everything (seed=123)
seed_everything()
def_device
'cpu'
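seed_everything fixes the global random seeds and def_device reports the default compute device. A minimal sketch of what such utilities typically do (an assumption for illustration, not the library's exact code):

import random
import numpy as np
import torch

def seed_everything_sketch(seed=123):
    # seed Python, NumPy, and PyTorch RNGs for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# pick the GPU when one is available, otherwise fall back to the CPU
def_device = 'cuda' if torch.cuda.is_available() else 'cpu'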

Load Data

# read training data
df = pd.read_parquet('https://github.com/sky1ove/katlas_raw/raw/refs/heads/main/nbs/raw/combine_t5_kd.parquet').reset_index()

# read data that contains info for the split
info_df = Data.get_kinase_info().query('pseudo!="1"') # keep non-pseudo kinases

# merge info with training data
info = df[['kinase']].merge(info_df)
info.head()

# splits
splits = get_splits(info,stratified='group')
split0 = splits[0]


# column names of features and targets
feat_col = df.columns[df.columns.str.startswith('T5_')]
target_col = df.columns[~df.columns.isin(feat_col)][1:]
StratifiedKFold(n_splits=5, random_state=123, shuffle=True)
# kinase group in train set: 9
# kinase group in test set: 9
---------------------------
# kinase in train set: 312
---------------------------
# kinase in test set: 78
---------------------------
test set: ['EPHA3' 'FES' 'FLT3' 'FYN' 'EPHB1' 'EPHB3' 'FER' 'EPHB4' 'FLT4' 'FGFR1' 'EPHA5' 'TEK' 'DDR2' 'ZAP70' 'LIMK1' 'ULK3' 'JAK1' 'WEE1' 'TESK1' 'MAP2K3' 'AMPKA2' 'ATM' 'CAMK1D' 'CAMK2D' 'CAMK4' 'CAMKK1'
 'CK1D' 'CK1E' 'DYRK2' 'DYRK4' 'HGK' 'IKKE' 'JNK2' 'JNK3' 'KHS1' 'MAPKAPK5' 'MEK2' 'MSK2' 'NDR1' 'NEK6' 'NEK9' 'NIM1' 'NLK' 'OSR1' 'P38A' 'P38B' 'P90RSK' 'PAK1' 'PERK' 'PKCH' 'PKCI' 'PKN1' 'ROCK2'
 'RSK2' 'SIK' 'STLK3' 'TAK1' 'TSSK1' 'ALPHAK3' 'BMPR2' 'CDK10' 'CDK13' 'CDK14' 'CDKL5' 'GCN2' 'GRK4' 'IRE1' 'KHS2' 'MASTL' 'MLK4' 'MNK1' 'MRCKA' 'PRPK' 'QSK' 'SMMLCK' 'SSTK' 'ULK2' 'VRK1']
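get_splits stratifies the folds on the chosen column so that every fold covers all kinase groups; the printed output above suggests it wraps scikit-learn's StratifiedKFold. A sketch under that assumption (the 'group' column name is inferred from stratified='group'):

from sklearn.model_selection import StratifiedKFold

# stratify 5 folds on the kinase group and return (train_idx, test_idx) index pairs
skf = StratifiedKFold(n_splits=5, random_state=123, shuffle=True)
splits = list(skf.split(info, info['group']))
split0 = splits[0]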

Dataset


source

GeneralDataset

 GeneralDataset (df, feat_col, target_col=None)

A general dataset that can be applied to any dataframe

Type Default Details
df a dataframe of values
feat_col feature columns
target_col NoneType None if None, the dataset returns features only (a test set for prediction)
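A minimal sketch of such a dataframe-backed dataset (an assumed implementation for illustration, not necessarily the library's):

import torch
from torch.utils.data import Dataset

class GeneralDataset_sketch(Dataset):
    # wraps a dataframe; returns (features, targets) pairs, or features only when target_col is None
    def __init__(self, df, feat_col, target_col=None):
        self.x = torch.tensor(df[feat_col].values, dtype=torch.float32)
        self.y = None if target_col is None else torch.tensor(df[target_col].values, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx] if self.y is None else (self.x[idx], self.y[idx])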
# dataset
ds = GeneralDataset(df,feat_col,target_col)
len(ds)
390
dl = DataLoader(ds, batch_size=64, shuffle=True)

source

get_sampler

 get_sampler (info, col)

For imbalanced data, get higher weights for less-represented samples
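A sketch of how such a sampler can be built with torch's WeightedRandomSampler, weighting each sample by the inverse frequency of its class in col (an assumed implementation):

import torch
from torch.utils.data import WeightedRandomSampler

def get_sampler_sketch(info, col):
    # weight each row by 1 / (count of its class in `col`) so rare classes are drawn more often
    counts = info[col].value_counts()
    weights = info[col].map(lambda c: 1.0 / counts[c]).values
    return WeightedRandomSampler(torch.as_tensor(weights, dtype=torch.double),
                                 num_samples=len(weights), replacement=True)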

sampler = get_sampler(info,'subfamily')
# dataloader
dl = DataLoader(ds, batch_size=64, sampler=sampler)
xb,yb = next(iter(dl))

xb.shape,yb.shape
(torch.Size([64, 1024]), torch.Size([64, 210]))

Models

MLP


source

MLP_1

 MLP_1 (num_features, num_targets, hidden_units=[512, 218], dp=0.2)
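The signature suggests a stack of fully connected blocks; a plausible sketch (the hidden-layer details are assumptions and may differ from the library's code):

import torch.nn as nn

class MLP_sketch(nn.Module):
    # hypothetical: Linear -> BatchNorm -> Dropout -> ReLU per hidden layer, then a linear head
    def __init__(self, num_features, num_targets, hidden_units=[512, 218], dp=0.2):
        super().__init__()
        layers, ni = [], num_features
        for nf in hidden_units:
            layers += [nn.Linear(ni, nf), nn.BatchNorm1d(nf), nn.Dropout(dp), nn.ReLU()]
            ni = nf
        layers.append(nn.Linear(ni, num_targets))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)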
n_feature = len(feat_col)
n_target = len(target_col)
model = MLP_1(n_feature, n_target)
model(xb)
tensor([[-0.1115, -0.3755, -0.3818,  ..., -0.1483, -0.0387, -0.1111],
        [ 0.8555,  0.9352, -0.9642,  ..., -0.4723,  0.7757, -0.0121],
        [ 0.3422,  0.3537, -0.1441,  ...,  0.5467, -0.4535,  0.2103],
        ...,
        [-0.4287,  0.6751,  0.1797,  ...,  0.0192,  0.0692, -0.0573],
        [-0.0206, -0.1953,  0.7445,  ..., -0.2206, -0.1188,  0.4579],
        [ 0.2342, -0.0243,  0.4630,  ...,  0.8393,  0.5747, -0.6881]], grad_fn=<AddmmBackward0>)

CNN1D

Version 1


source

CNN1D_1

 CNN1D_1 (num_features, num_targets)

Same as nn.Module, but no need for subclasses to call super().__init__

Details
num_features not used by the architecture; kept only for a consistent interface
num_targets
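A 1D CNN treats the flat embedding as a one-channel signal; a hypothetical minimal example of that reshaping pattern (not the library's actual architecture):

import torch.nn as nn

class TinyCNN1D(nn.Module):
    # hypothetical: unsqueeze to (bs, 1, num_features), convolve, pool, then a linear head;
    # num_features is not needed here, consistent with the note above
    def __init__(self, num_features, num_targets):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv1d(1, 16, kernel_size=5, padding=2),
                                  nn.ReLU(), nn.AdaptiveAvgPool1d(1))
        self.head = nn.Linear(16, num_targets)

    def forward(self, x):
        x = self.conv(x.unsqueeze(1)).flatten(1)  # (bs, num_features) -> (bs, 16)
        return self.head(x)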
model = CNN1D_1(n_feature, n_target)
model(xb)
tensor([[ 0.0193,  0.0690,  0.0138,  ..., -0.0428, -0.0026,  0.0840],
        [ 0.0203,  0.0693,  0.0136,  ..., -0.0422, -0.0023,  0.0846],
        [ 0.0198,  0.0703,  0.0148,  ..., -0.0424, -0.0029,  0.0839],
        ...,
        [ 0.0197,  0.0694,  0.0147,  ..., -0.0429, -0.0019,  0.0841],
        [ 0.0193,  0.0687,  0.0146,  ..., -0.0429, -0.0017,  0.0843],
        [ 0.0191,  0.0692,  0.0148,  ..., -0.0425, -0.0028,  0.0834]], grad_fn=<AddmmBackward0>)

Version 2


source

init_weights

 init_weights (m, leaky=0.0)

Initialize any Conv layer with Kaiming normal initialization.
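A sketch of such an initializer, meant to be passed to model.apply() (assumed implementation):

import torch.nn as nn

def init_weights_sketch(m, leaky=0.0):
    # apply Kaiming-normal initialization to conv layers; `leaky` is the negative slope
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_normal_(m.weight, a=leaky)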


source

lin_wn

 lin_wn (ni, nf, dp=0.1, act=<class 'torch.nn.modules.activation.SiLU'>)

A weight-normalized linear block.


source

conv_wn

 conv_wn (ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=<class
          'torch.nn.modules.activation.ReLU'>)

A weight-normalized conv block.
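Both helpers likely wrap torch.nn.utils.weight_norm around the corresponding layer and add dropout and an activation; a sketch under that assumption:

import torch.nn as nn
from torch.nn.utils import weight_norm

def lin_wn_sketch(ni, nf, dp=0.1, act=nn.SiLU):
    # weight-normalized linear block: Linear -> Dropout -> activation
    return nn.Sequential(weight_norm(nn.Linear(ni, nf)), nn.Dropout(dp), act())

def conv_wn_sketch(ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=nn.ReLU):
    # weight-normalized 1D conv block: Conv1d -> Dropout -> activation
    return nn.Sequential(weight_norm(nn.Conv1d(ni, nf, ks, stride, padding)),
                         nn.Dropout(dp), act())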


source

CNN1D_2

 CNN1D_2 (ni, nf, amp_scale=16)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and their parameters will be converted too when you call to(), etc.

Note: as in the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

model = CNN1D_2(n_feature,n_target).apply(init_weights)
model(xb)
tensor([[-0.5740,  0.0151, -0.0819,  ...,  0.2636,  0.3405, -0.1404],
        [-0.6800,  0.5530, -0.0958,  ..., -0.3752, -0.6124,  0.7171],
        [ 0.4427, -0.3204, -0.3243,  ..., -0.2290,  0.1070,  0.1504],
        ...,
        [-0.3660, -0.2667, -0.6036,  ..., -0.3130,  0.5462, -0.0055],
        [ 0.4511,  0.6824,  0.8659,  ..., -0.0171,  0.2362, -0.3475],
        [-0.0746, -0.1699,  0.6895,  ...,  1.1522, -0.3472,  0.6422]], grad_fn=<AddmmBackward0>)

DL Trainer


source

train_dl

 train_dl (df, feat_col, target_col, split, model_func, n_epoch=4, bs=32,
           lr=0.01, loss=<function mse>, save=None, sampler=None,
           lr_find=False)

A DL trainer.

Type Default Details
df
feat_col
target_col
split tuple of numpy arrays with train/valid indices
model_func function that returns a PyTorch model
n_epoch int 4 number of epochs
bs int 32 batch size
lr float 0.01 ignored if lr_find is True
loss function mse loss function
save NoneType None if given, saves the model to models/{save}.pth
sampler NoneType None
lr_find bool False if True, uses the learning rate suggested by lr_find
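Under the hood this is roughly the fastai Learner workflow. A sketch of the pieces it likely assembles (train_dl_sketch and its details are assumptions; GeneralDataset is the class defined above):

import torch.nn.functional as F
from fastai.data.core import DataLoaders
from fastai.learner import Learner
from fastai.metrics import PearsonCorrCoef, SpearmanCorrCoef

def train_dl_sketch(df, feat_col, target_col, split, model_func, n_epoch=4, bs=32, lr=0.01, save=None):
    # build train/valid datasets from the split indices
    train_ds = GeneralDataset(df.loc[split[0]], feat_col, target_col)
    valid_ds = GeneralDataset(df.loc[split[1]], feat_col, target_col)
    dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=bs)
    learn = Learner(dls, model_func(), loss_func=F.mse_loss,
                    metrics=[PearsonCorrCoef(), SpearmanCorrCoef()])
    learn.fit(n_epoch, lr=lr)            # the library may use a one-cycle schedule instead
    if save: learn.save(save)            # writes models/{save}.pth
    preds, targets = learn.get_preds(dl=dls.valid)
    return targets.numpy(), preds.numpy()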
def get_model():
    return CNN1D_2(n_feature, n_target)
target, pred = train_dl(df, 
                        feat_col, 
                        target_col,
                        split0, 
                        get_model,
                        n_epoch=1,
                        lr = 1e-2,
                        save = 'test')
lr in training is 0.01
epoch train_loss valid_loss pearsonr spearmanr time
0 2.194746 2.176958 -0.104578 -0.053141 00:04
score_each(target,pred)
overall MSE: 2.1770
Average Pearson: 0.2149 
(2.176958,
 0.21488270397776432,
       Pearson
 3   -0.442855
 8   -0.490345
 10  -0.401885
 19  -0.428557
 24  -0.383956
 ..        ...
 359 -0.127000
 361  0.005761
 366  0.095977
 367  0.335805
 373 -0.230842
 
 [78 rows x 1 columns])
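score_each reports the overall MSE, the average per-row Pearson correlation, and the per-row correlations themselves; a rough sketch of that computation (assumed, not the library code):

import numpy as np
import pandas as pd

def score_each_sketch(target, pred):
    # overall MSE plus Pearson r computed row by row across the target columns
    t, p = np.asarray(target, dtype=float), np.asarray(pred, dtype=float)
    mse = ((t - p) ** 2).mean()
    corr = pd.DataFrame({'Pearson': [np.corrcoef(t[i], p[i])[0, 1] for i in range(len(t))]})
    return mse, corr.Pearson.mean(), corr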

DL CV


source

train_dl_cv

 train_dl_cv (df, feat_col, target_col, splits, model_func, save:str=None,
              n_epoch=4, bs=32, lr=0.01, loss=<function mse>,
              sampler=None, lr_find=False)
Type Default Details
df
feat_col
target_col
splits list of tuples of (train, valid) index arrays
model_func function that returns a PyTorch model, e.g. lambda: MLP_1(num_feat, num_target)
save NoneType None if given, saves the model(s) to models/{save}.pth
n_epoch int 4 number of epochs
bs int 32 batch size
lr float 0.01 ignored if lr_find is True
loss function mse loss function
sampler NoneType None
lr_find bool False if True, uses the learning rate suggested by lr_find
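A sketch of the cross-validation loop this implies: train one model per fold with train_dl, collect out-of-fold predictions, and score each fold (assumed, not the library's exact code):

import numpy as np
import pandas as pd

def train_dl_cv_sketch(df, feat_col, target_col, splits, model_func, **kwargs):
    oof = np.zeros((len(df), len(target_col)))
    rows = []
    for fold, split in enumerate(splits):
        print(f'------fold{fold}------')
        target, pred = train_dl(df, feat_col, target_col, split, model_func, **kwargs)
        oof[split[1]] = pred                      # out-of-fold predictions for this fold's valid set
        mse, pearson_avg, _ = score_each(target, pred)
        rows.append({'fold': fold, 'mse': mse, 'pearson_avg': pearson_avg})
    return pd.DataFrame(oof, columns=target_col), pd.DataFrame(rows)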
def get_model():
    return CNN1D_2(n_feature, n_target)
oof,metrics = train_dl_cv(df,feat_col,target_col,splits,get_model,n_epoch=1,lr=3e-3)
------fold0------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.165076 0.997911 0.091948 0.058285 00:01
overall MSE: 0.9979
Average Pearson: 0.1634 
------fold1------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.180757 0.992539 0.102852 0.084205 00:01
overall MSE: 0.9925
Average Pearson: 0.1617 
------fold2------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.159264 0.987170 0.119972 0.098912 00:01
overall MSE: 0.9872
Average Pearson: 0.2364 
------fold3------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.184666 1.001829 0.077155 0.047876 00:01
overall MSE: 1.0018
Average Pearson: 0.1415 
------fold4------
lr in training is 0.003
epoch train_loss valid_loss pearsonr spearmanr time
0 1.178444 0.992547 0.109969 0.100576 00:01
overall MSE: 0.9925
Average Pearson: 0.2014 
metrics
fold mse pearson_avg
0 0 0.997911 0.163423
1 1 0.992539 0.161654
2 2 0.987170 0.236363
3 3 1.001829 0.141464
4 4 0.992547 0.201375
metrics.pearson_avg.mean()
0.18085578818910147
target = df[target_col]
_,_,corr = score_each(target,oof)
overall MSE: 0.9944
Average Pearson: 0.1809 
corr
Pearson
0 -0.183429
1 -0.178564
2 -0.225202
3 -0.117838
4 -0.153463
... ...
385 0.109792
386 0.269238
387 0.079601
388 0.063310
389 0.343237

390 rows × 1 columns

DL Predict


source

predict_dl

 predict_dl (df, feat_col, target_col, model, model_pth)

Predict on a dataframe given a deep learning model

Details
df
feat_col
target_col
model model architecture
model_pth model name only, without the .pth extension
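A sketch of the prediction path: load the saved weights, run batched inference over the features, and return a dataframe with the target columns (assumed behavior; predict_dl_sketch is hypothetical):

import torch
import pandas as pd
from torch.utils.data import DataLoader

def predict_dl_sketch(df, feat_col, target_col, model, model_pth):
    # load weights from models/{model_pth}.pth and predict on the feature columns
    model.load_state_dict(torch.load(f'models/{model_pth}.pth', map_location='cpu'))
    model.eval()
    dl = DataLoader(GeneralDataset(df, feat_col), batch_size=64)
    with torch.no_grad():
        preds = torch.cat([model(xb) for xb in dl])
    return pd.DataFrame(preds.numpy(), columns=target_col, index=df.index)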
test = df.loc[split0[1]]
pred = predict_dl(test.head(3),
                  feat_col,
                  target_col, 
                  model,'test')
pred
(3 rows × 210 columns of predicted values; columns run over the position/residue names -5P through 0y, rows indexed 3, 8, and 10)
_,_,corr = score_each(test[target_col].head(3),pred)
overall MSE: 3.1266
Average Pearson: 0.1593