Train DL
Setup
Utils
seed_everything
seed_everything (seed=123)
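`seed_everything` is only shown by signature here; the sketch below is a typical implementation of such a helper (seeding the Python, NumPy, and PyTorch RNGs), assumed for illustration rather than taken from the katlas source.

def seed_everything_sketch(seed=123):
    # seed every RNG that commonly affects a training run
    import os, random
    import numpy as np
    import torch
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)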
def_device
'cpu'
Load Data
# read training data
df = pd.read_parquet('https://github.com/sky1ove/katlas_raw/raw/refs/heads/main/nbs/raw/combine_t5_kd.parquet').reset_index()
# read data that contains info for the split
info_df = Data.get_kinase_info().query('pseudo!="1"') # get non-pseudo kinases
# merge info with training data
info = df[['kinase']].merge(info_df)
info.head()
# splits
splits = get_splits(info, stratified='group')
split0 = splits[0]
# column names of features and targets
feat_col = df.columns[df.columns.str.startswith('T5_')]
target_col = df.columns[~df.columns.isin(feat_col)][1:]
StratifiedKFold(n_splits=5, random_state=123, shuffle=True)
# kinase group in train set: 9
# kinase group in test set: 9
---------------------------
# kinase in train set: 312
---------------------------
# kinase in test set: 78
---------------------------
test set: ['EPHA3' 'FES' 'FLT3' 'FYN' 'EPHB1' 'EPHB3' 'FER' 'EPHB4' 'FLT4' 'FGFR1' 'EPHA5' 'TEK' 'DDR2' 'ZAP70' 'LIMK1' 'ULK3' 'JAK1' 'WEE1' 'TESK1' 'MAP2K3' 'AMPKA2' 'ATM' 'CAMK1D' 'CAMK2D' 'CAMK4' 'CAMKK1'
'CK1D' 'CK1E' 'DYRK2' 'DYRK4' 'HGK' 'IKKE' 'JNK2' 'JNK3' 'KHS1' 'MAPKAPK5' 'MEK2' 'MSK2' 'NDR1' 'NEK6' 'NEK9' 'NIM1' 'NLK' 'OSR1' 'P38A' 'P38B' 'P90RSK' 'PAK1' 'PERK' 'PKCH' 'PKCI' 'PKN1' 'ROCK2'
'RSK2' 'SIK' 'STLK3' 'TAK1' 'TSSK1' 'ALPHAK3' 'BMPR2' 'CDK10' 'CDK13' 'CDK14' 'CDKL5' 'GCN2' 'GRK4' 'IRE1' 'KHS2' 'MASTL' 'MLK4' 'MNK1' 'MRCKA' 'PRPK' 'QSK' 'SMMLCK' 'SSTK' 'ULK2' 'VRK1']
Dataset
GeneralDataset
GeneralDataset (df, feat_col, target_col=None)
A general dataset that can be applied to any dataframe
| | Type | Default | Details |
|---|---|---|---|
| df | | | a dataframe of values |
| feat_col | | | feature columns |
| target_col | NoneType | None | will return a test set for prediction if target_col is None |
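As a rough illustration of the behavior described above, a minimal dataset with the same interface could look like the following sketch (a hypothetical re-implementation, not the katlas source).

import torch
from torch.utils.data import Dataset

class GeneralDatasetSketch(Dataset):
    "Hypothetical minimal equivalent of GeneralDataset"
    def __init__(self, df, feat_col, target_col=None):
        self.x = torch.tensor(df[feat_col].values, dtype=torch.float32)
        self.y = None if target_col is None else torch.tensor(df[target_col].values, dtype=torch.float32)
    def __len__(self):
        return len(self.x)
    def __getitem__(self, i):
        # test mode: return features only when no target column is given
        return self.x[i] if self.y is None else (self.x[i], self.y[i])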
# dataset
ds = GeneralDataset(df, feat_col, target_col)
len(ds)
390
dl = DataLoader(ds, batch_size=64, shuffle=True)
get_sampler
get_sampler (info, col)
For imbalanced data, assign higher weights to less-represented samples.
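A plausible implementation of this weighting uses torch's WeightedRandomSampler with inverse class frequency; the body below is an assumption, and the katlas version may weight differently.

from torch.utils.data import WeightedRandomSampler

def get_sampler_sketch(info, col):
    # weight each row inversely to the frequency of its class in `col`
    counts = info[col].value_counts()
    weights = info[col].map(1.0 / counts).values
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)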
sampler = get_sampler(info, 'subfamily')
# dataloader
dl = DataLoader(ds, batch_size=64, sampler=sampler)
xb, yb = next(iter(dl))
xb.shape,yb.shape
(torch.Size([64, 1024]), torch.Size([64, 210]))
Models
MLP
MLP_1
MLP_1 (num_features, num_targets, hidden_units=[512, 218], dp=0.2)
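The signature suggests a stack of fully connected blocks sized by hidden_units; a sketch consistent with it follows, where the exact layer composition (BatchNorm, ReLU ordering) is assumed rather than taken from the source.

import torch.nn as nn

def mlp_1_sketch(num_features, num_targets, hidden_units=[512, 218], dp=0.2):
    # one Linear -> BatchNorm -> ReLU -> Dropout block per hidden size
    layers, ni = [], num_features
    for nf in hidden_units:
        layers += [nn.Linear(ni, nf), nn.BatchNorm1d(nf), nn.ReLU(), nn.Dropout(dp)]
        ni = nf
    layers.append(nn.Linear(ni, num_targets))
    return nn.Sequential(*layers)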
n_feature = len(feat_col)
n_target = len(target_col)
model = MLP_1(n_feature, n_target)
model(xb)
tensor([[-0.1115, -0.3755, -0.3818, ..., -0.1483, -0.0387, -0.1111],
[ 0.8555, 0.9352, -0.9642, ..., -0.4723, 0.7757, -0.0121],
[ 0.3422, 0.3537, -0.1441, ..., 0.5467, -0.4535, 0.2103],
...,
[-0.4287, 0.6751, 0.1797, ..., 0.0192, 0.0692, -0.0573],
[-0.0206, -0.1953, 0.7445, ..., -0.2206, -0.1188, 0.4579],
[ 0.2342, -0.0243, 0.4630, ..., 0.8393, 0.5747, -0.6881]], grad_fn=<AddmmBackward0>)
CNN1D
Version 1
CNN1D_1
CNN1D_1 (num_features, num_targets)
Same as nn.Module, but subclasses do not need to call super().__init__.
| | Details |
|---|---|
| num_features | not used; kept only for a consistent interface |
| num_targets | |
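One way such a model can ignore num_features is to treat the embedding as a one-channel sequence and pool it to a fixed length before the output head. The architecture below is an assumed illustration, not the katlas implementation.

import torch.nn as nn

class CNN1D1Sketch(nn.Module):
    "Hypothetical 1D CNN that treats the embedding as a 1-channel sequence"
    def __init__(self, num_features, num_targets):  # num_features is unused
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=5, padding=2), nn.ReLU(),
            nn.AdaptiveAvgPool1d(64), nn.Flatten())
        self.head = nn.Linear(16 * 64, num_targets)
    def forward(self, x):
        return self.head(self.body(x.unsqueeze(1)))  # (bs, 1024) -> (bs, 1, 1024)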
model = CNN1D_1(n_feature, n_target)
model(xb)
tensor([[ 0.0193, 0.0690, 0.0138, ..., -0.0428, -0.0026, 0.0840],
[ 0.0203, 0.0693, 0.0136, ..., -0.0422, -0.0023, 0.0846],
[ 0.0198, 0.0703, 0.0148, ..., -0.0424, -0.0029, 0.0839],
...,
[ 0.0197, 0.0694, 0.0147, ..., -0.0429, -0.0019, 0.0841],
[ 0.0193, 0.0687, 0.0146, ..., -0.0429, -0.0017, 0.0843],
[ 0.0191, 0.0692, 0.0148, ..., -0.0425, -0.0028, 0.0834]], grad_fn=<AddmmBackward0>)
Version 2
init_weights
init_weights (m, leaky=0.0)
Initialize any Conv layer with Kaiming normal initialization.
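A minimal version matching that description, assuming it targets convolution layers only:

import torch.nn as nn

def init_weights_sketch(m, leaky=0.0):
    # Kaiming-normal init for any convolution layer
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_normal_(m.weight, a=leaky)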
lin_wn
lin_wn (ni, nf, dp=0.1, act=<class 'torch.nn.modules.activation.SiLU'>)
Weight-normalized linear layer.
conv_wn
conv_wn (ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=<class 'torch.nn.modules.activation.ReLU'>)
Weight-normalized conv layer.
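Both helpers likely wrap a layer in torch's weight normalization and pair it with dropout and an activation; the sketches below follow that assumption and may differ from the katlas bodies.

import torch.nn as nn

def lin_wn_sketch(ni, nf, dp=0.1, act=nn.SiLU):
    # weight-normalized Linear followed by dropout and activation
    return nn.Sequential(nn.utils.weight_norm(nn.Linear(ni, nf)), nn.Dropout(dp), act())

def conv_wn_sketch(ni, nf, ks=3, stride=1, padding=1, dp=0.1, act=nn.ReLU):
    # weight-normalized Conv1d followed by dropout and activation
    return nn.Sequential(
        nn.utils.weight_norm(nn.Conv1d(ni, nf, ks, stride=stride, padding=padding)),
        nn.Dropout(dp), act())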
CNN1D_2
CNN1D_2 (ni, nf, amp_scale=16)
(Docstring inherited from nn.Module: the base class for all neural network modules; see the PyTorch documentation.)
model = CNN1D_2(n_feature, n_target).apply(init_weights)
model(xb)
tensor([[-0.5740, 0.0151, -0.0819, ..., 0.2636, 0.3405, -0.1404],
[-0.6800, 0.5530, -0.0958, ..., -0.3752, -0.6124, 0.7171],
[ 0.4427, -0.3204, -0.3243, ..., -0.2290, 0.1070, 0.1504],
...,
[-0.3660, -0.2667, -0.6036, ..., -0.3130, 0.5462, -0.0055],
[ 0.4511, 0.6824, 0.8659, ..., -0.0171, 0.2362, -0.3475],
[-0.0746, -0.1699, 0.6895, ..., 1.1522, -0.3472, 0.6422]], grad_fn=<AddmmBackward0>)
DL Trainer
train_dl
train_dl (df, feat_col, target_col, split, model_func, n_epoch=4, bs=32, lr=0.01, loss=<function mse>, save=None, sampler=None, lr_find=False)
A DL trainer.
| | Type | Default | Details |
|---|---|---|---|
| df | | | |
| feat_col | | | |
| target_col | | | |
| split | | | tuple of numpy arrays with train/valid indices |
| model_func | | | function that returns a pytorch model |
| n_epoch | int | 4 | number of epochs |
| bs | int | 32 | batch size |
| lr | float | 0.01 | ignored if lr_find is True |
| loss | function | mse | loss function |
| save | NoneType | None | saves weights to models/{save}.pth |
| sampler | NoneType | None | |
| lr_find | bool | False | if True, use the learning rate suggested by lr_find |
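The logged table below suggests a fastai-style training loop; a plain-PyTorch outline of what train_dl plausibly does (simplified: no metric logging, lr_find, sampler, or weight saving) is sketched here as an assumption, not the actual implementation.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def train_dl_sketch(df, feat_col, target_col, split, model_func, n_epoch=4, bs=32, lr=0.01):
    train_idx, valid_idx = split
    train_loader = DataLoader(GeneralDataset(df.loc[train_idx], feat_col, target_col),
                              batch_size=bs, shuffle=True)
    model = model_func()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(n_epoch):
        model.train()
        for xb, yb in train_loader:
            opt.zero_grad()
            F.mse_loss(model(xb), yb).backward()
            opt.step()
    # predict on the validation split and return (target, pred)
    model.eval()
    valid = df.loc[valid_idx]
    with torch.no_grad():
        pred = model(torch.tensor(valid[feat_col].values, dtype=torch.float32)).numpy()
    return valid[target_col], pred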
def get_model():
    return CNN1D_2(n_feature, n_target)

target, pred = train_dl(df,
                        feat_col,
                        target_col,
                        split0,
                        get_model,
                        n_epoch=1,
                        lr=1e-2,
                        save='test')
lr in training is 0.01

| epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
|---|---|---|---|---|---|
| 0 | 2.194746 | 2.176958 | -0.104578 | -0.053141 | 00:04 |
score_each(target,pred)
overall MSE: 2.1770
Average Pearson: 0.2149
(2.176958,
0.21488270397776432,
Pearson
3 -0.442855
8 -0.490345
10 -0.401885
19 -0.428557
24 -0.383956
.. ...
359 -0.127000
361 0.005761
366 0.095977
367 0.335805
373 -0.230842
[78 rows x 1 columns])
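As the output shows, score_each returns the overall MSE, the average per-row Pearson, and the per-row correlations as a dataframe. A hypothetical re-implementation consistent with that output:

import numpy as np
import pandas as pd
from scipy import stats

def score_each_sketch(target, pred):
    # overall MSE plus one Pearson r per row (i.e., per kinase)
    t, p = np.asarray(target), np.asarray(pred)
    mse = ((t - p) ** 2).mean()
    corr = pd.DataFrame({'Pearson': [stats.pearsonr(a, b)[0] for a, b in zip(t, p)]},
                        index=target.index)
    return mse, corr['Pearson'].mean(), corr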
DL CV
train_dl_cv
train_dl_cv (df, feat_col, target_col, splits, model_func, save:str=None, n_epoch=4, bs=32, lr=0.01, loss=<function mse>, sampler=None, lr_find=False)
| | Type | Default | Details |
|---|---|---|---|
| df | | | |
| feat_col | | | |
| target_col | | | |
| splits | | | list of (train, valid) index tuples |
| model_func | | | function that returns a model, e.g. lambda: MLP_1(num_feat, num_target) |
| save | NoneType | None | saves weights to models/{save}.pth |
| n_epoch | int | 4 | number of epochs |
| bs | int | 32 | batch size |
| lr | float | 0.01 | ignored if lr_find is True |
| loss | function | mse | loss function |
| sampler | NoneType | None | |
| lr_find | bool | False | if True, use the learning rate suggested by lr_find |
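train_dl_cv presumably loops train_dl over the folds and assembles out-of-fold (OOF) predictions plus per-fold metrics; the sketch below follows that assumption.

import pandas as pd

def train_dl_cv_sketch(df, feat_col, target_col, splits, model_func, **kwargs):
    oof = pd.DataFrame(0.0, index=df.index, columns=target_col)
    rows = []
    for fold, (train_idx, valid_idx) in enumerate(splits):
        target, pred = train_dl(df, feat_col, target_col, (train_idx, valid_idx),
                                model_func, **kwargs)
        oof.loc[valid_idx] = pred  # each row is predicted exactly once across folds
        mse, pearson_avg, _ = score_each(target, pred)
        rows.append({'fold': fold, 'mse': mse, 'pearson_avg': pearson_avg})
    return oof, pd.DataFrame(rows)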
def get_model():
    return CNN1D_2(n_feature, n_target)

oof, metrics = train_dl_cv(df, feat_col, target_col, splits, get_model, n_epoch=1, lr=3e-3)
------fold0------
lr in training is 0.003

| epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
|---|---|---|---|---|---|
| 0 | 1.165076 | 0.997911 | 0.091948 | 0.058285 | 00:01 |

overall MSE: 0.9979
Average Pearson: 0.1634
------fold1------
lr in training is 0.003

| epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
|---|---|---|---|---|---|
| 0 | 1.180757 | 0.992539 | 0.102852 | 0.084205 | 00:01 |

overall MSE: 0.9925
Average Pearson: 0.1617
------fold2------
lr in training is 0.003

| epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
|---|---|---|---|---|---|
| 0 | 1.159264 | 0.987170 | 0.119972 | 0.098912 | 00:01 |

overall MSE: 0.9872
Average Pearson: 0.2364
------fold3------
lr in training is 0.003

| epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
|---|---|---|---|---|---|
| 0 | 1.184666 | 1.001829 | 0.077155 | 0.047876 | 00:01 |

overall MSE: 1.0018
Average Pearson: 0.1415
------fold4------
lr in training is 0.003

| epoch | train_loss | valid_loss | pearsonr | spearmanr | time |
|---|---|---|---|---|---|
| 0 | 1.178444 | 0.992547 | 0.109969 | 0.100576 | 00:01 |

overall MSE: 0.9925
Average Pearson: 0.2014
metrics
| | fold | mse | pearson_avg |
|---|---|---|---|
| 0 | 0 | 0.997911 | 0.163423 |
| 1 | 1 | 0.992539 | 0.161654 |
| 2 | 2 | 0.987170 | 0.236363 |
| 3 | 3 | 1.001829 | 0.141464 |
| 4 | 4 | 0.992547 | 0.201375 |
metrics.pearson_avg.mean()
0.18085578818910147
target = df[target_col]
_, _, corr = score_each(target, oof)
overall MSE: 0.9944
Average Pearson: 0.1809
corr
| | Pearson |
|---|---|
| 0 | -0.183429 |
| 1 | -0.178564 |
| 2 | -0.225202 |
| 3 | -0.117838 |
| 4 | -0.153463 |
| ... | ... |
| 385 | 0.109792 |
| 386 | 0.269238 |
| 387 | 0.079601 |
| 388 | 0.063310 |
| 389 | 0.343237 |

390 rows × 1 columns
DL Predict
predict_dl
predict_dl (df, feat_col, target_col, model, model_pth)
Predict dataframe given a deep learning model
| | Details |
|---|---|
| df | |
| feat_col | |
| target_col | |
| model | model architecture |
| model_pth | name only, without the .pth extension |
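A plausible outline of predict_dl, assuming weights were saved under models/{model_pth}.pth as in train_dl (a sketch, not the katlas source):

import torch
import pandas as pd
from torch.utils.data import DataLoader

def predict_dl_sketch(df, feat_col, target_col, model, model_pth):
    # load saved weights, run inference in test mode, return a labeled dataframe
    model.load_state_dict(torch.load(f'models/{model_pth}.pth', map_location='cpu'))
    model.eval()
    ds = GeneralDataset(df, feat_col)  # no target -> features only
    with torch.no_grad():
        preds = torch.cat([model(xb) for xb in DataLoader(ds, batch_size=64)])
    return pd.DataFrame(preds.numpy(), index=df.index, columns=target_col)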
test = df.loc[split0[1]]
pred = predict_dl(test.head(3),
                  feat_col,
                  target_col,
                  model,
                  'test')
pred
| | -5P | -5G | -5A | ... | 0s | 0t | 0y |
|---|---|---|---|---|---|---|---|
| 3 | -1.124083 | -1.333018 | 0.370999 | ... | 1.808517 | -0.691114 | 1.019787 |
| 8 | -1.082794 | -1.315036 | 0.377679 | ... | 1.831862 | -0.617930 | 0.849381 |
| 10 | -1.123702 | -1.333739 | 0.366915 | ... | 1.808313 | -0.695774 | 1.022946 |

3 rows × 210 columns (position-specific columns -5P through 0y)
_, _, corr = score_each(test[target_col].head(3), pred)
overall MSE: 3.1266
Average Pearson: 0.1593