from katlas.core import *
from katlas.plot import *
import pandas as pd,numpy as np
from matplotlib import pyplot as plt
Scoring evaluation
Functions
def split_data(df):
= df[df.kinase_group !='TK'].copy().reset_index(drop=True)
df_st = df[df.kinase_group =='TK'].copy().reset_index(drop=True)
df_tyr
= df_st[df_st.site_seq.str[20]!='y'].copy().reset_index(drop=True)
df_st = df_tyr[df_tyr.site_seq.str[20]=='y'].copy().reset_index(drop=True)
df_tyr return df_st, df_tyr
def split_ref(ref):
= ref[ref.index.isin(ST)].copy()
ref_st = ref[ref.index.isin(TK)].copy()
ref_tk return ref_st,ref_tk
def get_kinase_rank(row_index,df,result):
= df.loc[row_index, 'kinase_protein']
kinase = result.loc[row_index]
scores = scores.sort_values(ascending=False)
ranked = ranked.index.get_loc(kinase) + 1 # +1 to make rank start from 1
rank return rank
def score_data(df,ref,seq_col='site_seq',func=sumup,pct_ref=None):
= split_data(df)
df_st, df_tyr = split_ref(ref)
ref_st,ref_tk
= predict_kinase_df(df_st,seq_col=seq_col,ref=ref_st,func=func)
df_st_result = predict_kinase_df(df_tyr,seq_col=seq_col,ref=ref_tk,func=func)
df_tyr_result
if pct_ref is not None:
= get_pct_df(df_st_result,pct_ref)
df_st_result = get_pct_df(df_tyr_result,pct_ref)
df_tyr_result
'rank'] = df_st_result.index.to_series().apply(lambda idx: get_kinase_rank(idx,df_st,df_st_result))
df_st['rank'] = df_tyr_result.index.to_series().apply(lambda idx: get_kinase_rank(idx,df_tyr,df_tyr_result))
df_tyr[
return df_st,df_tyr,df_st_result,df_tyr_result
def get_ref(df, pssm, seq_col='site_seq', func=sumup, n_chunk=5):
= (len(df) + n_chunk - 1) // n_chunk # ceiling division
chunk_size = []
pct_refs for i in range(n_chunk):
= i * chunk_size
start = min((i + 1) * chunk_size, len(df))
end if start >= end:
break
print(f"Processing chunk {i+1}/{n_chunk} → rows {start}:{end}")
= df.iloc[start:end]
sub_df = predict_kinase_df(
pct_ref
sub_df,=seq_col,
seq_col=pssm,
ref=func
func
)
pct_refs.append(pct_ref)return pd.concat(pct_refs, ignore_index=True)
def top_k_accuracy(df, result, k):
def is_in_top_k(row_index):
= df.loc[row_index, 'kinase_protein']
kinase = result.loc[row_index]
scores = scores.nlargest(k).index
top_k return kinase in top_k
return result.index.to_series().apply(is_in_top_k).mean()
def top_k_accuracy_group(group_indices,df,result, k=10):
def is_correct(row_index):
= df.loc[row_index, 'kinase_protein']
kinase = result.loc[row_index]
scores = scores.nlargest(k).index
top_k return kinase in top_k
return pd.Series(group_indices).apply(is_correct).mean()
def get_topk(df,result,k=10):
= df.groupby('kinase_group').groups
grouped
= {
topk_scores =k)
subfam: top_k_accuracy_group(indices,df,result,kfor subfam, indices in grouped.items()}
= pd.DataFrame.from_dict(topk_scores, orient='index', columns=['top10_accuracy']).sort_values('top10_accuracy', ascending=False)
topk_df return topk_df
def get_group_topk(df_st,df_tyr,df_st_result,df_tyr_result):
= get_topk(df_st,df_st_result)
topk_st = get_topk(df_tyr,df_tyr_result)
topk_tyr
= pd.concat([topk_st,topk_tyr]).sort_values('top10_accuracy',ascending=False)
topk
return topk
def get_results(df_st,df_tyr,df_st_result,df_tyr_result):
= get_group_topk(df_st,df_tyr,df_st_result,df_tyr_result)
topk = get_AUCDF(df_st,'rank')
aucdf_st = get_AUCDF(df_tyr,'rank')
aucdf_tyr
return topk, aucdf_st,aucdf_tyr
set_sns()
Overall
pspa
= pd.read_parquet('raw/overlap_pspa.parquet')
pspa = pspa.index.str.split('_').str[1]
pspa.index pspa.shape
(312, 236)
Data to be scored
= Data.get_ks_dataset() psp
= psp[psp.source.str.contains('PSP')].copy() df
'site_seq_upper'] = df.site_seq.str.upper() df[
# only filter those with tested kinase
= df[df.kinase_protein.isin(pspa.index)].copy().reset_index(drop=True) df
df.shape
(12639, 22)
= df[['kinase_protein','kinase_group']].drop_duplicates().set_index('kinase_protein')['kinase_group']
group_map = group_map[group_map=='TK'].index
TK = group_map[group_map!='TK'].index ST
def prepare_ref(path):
= pd.read_parquet(path)
ref = ref.index.str.split('_').str[1]
ref.index = ref[ref.index.isin(pspa.index)]
ref print(ref.shape)
return ref
cddm:
= prepare_ref('out/CDDM_pssms.parquet')
cddm = prepare_ref('out/CDDM_pssms_upper.parquet')
cddm_upper
= prepare_ref('out/CDDM_pssms_LO.parquet')
cddm_lo = prepare_ref('out/CDDM_pssms_LO_upper.parquet') cddm_lo_upper
(312, 943)
(312, 943)
(312, 943)
(312, 943)
# ref_dict = {
# 'PSPA: PSSM + multiply': (pspa,'site_seq',multiply,False),
# 'CDDM: PSSM + multiply': (cddm, 'site_seq', multiply_23,False),
# 'CDDM: PSSM upper + multiply': (cddm_upper, 'site_seq_upper', multiply_20,False),
# 'CDDM: LO + sum': (cddm_lo, 'site_seq',sumup,False),
# 'CDDM: LO upper + sum': (cddm_lo_upper, 'site_seq_upper',sumup,False),
# }
# human =Data.get_human_site()
# human['site_seq_upper'] = human.site_seq.str.upper()
# human = human.drop_duplicates('site_seq').reset_index()
# for name,(ref,seq_col,func,_) in ref_dict.items():
# print(name)
# pct_ref = get_ref(human,pssm=ref,seq_col=seq_col,func=func)
# pct_ref.to_parquet(f'raw/{name}.parquet')
# print(pct_ref.head())
# # break
= {
ref_dict 'PSPA: PSSM + multiply': (pspa,'site_seq',multiply,False),
'PSPA: PSSM + multiply + pct': (pspa,'site_seq',multiply, True),
'CDDM: PSSM + multiply': (cddm, 'site_seq', multiply_23,False),
'CDDM: PSSM + multiply + pct': (cddm, 'site_seq', multiply_23, True),
'CDDM: PSSM upper + multiply': (cddm_upper, 'site_seq_upper', multiply_20,False),
'CDDM: PSSM upper + multiply + pct': (cddm_upper, 'site_seq_upper', multiply_20,True),
'CDDM: LO + sum': (cddm_lo, 'site_seq',sumup,False),
'CDDM: LO + sum + pct': (cddm_lo, 'site_seq',sumup,True),
'CDDM: LO upper + sum': (cddm_lo_upper, 'site_seq_upper',sumup,False),
'CDDM: LO upper + sum + pct': (cddm_lo_upper, 'site_seq_upper',sumup,True),
}
= []
topk_dfs = pd.DataFrame()
aucdf for name,(ref,seq_col,func,use_pct) in ref_dict.items():
print('---------------------------------------------')
print(name)
= pd.read_parquet(f'raw/{name[:-6]}.parquet') if use_pct else None # -6 to remove + pct from parquet name
pct_ref = score_data(df,ref,seq_col=seq_col,func=func,pct_ref=pct_ref)
df_st,df_tyr,df_st_result,df_tyr_result = get_results(df_st,df_tyr,df_st_result,df_tyr_result)
topk,aucdf_st,aucdf_tyr
topk_dfs.append(topk)'ST'] = aucdf_st
aucdf.loc[name,'Tyr'] = aucdf_tyr aucdf.loc[name,
---------------------------------------------
PSPA: PSSM + multiply
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|███████████████████████████████████████████████████████████████████████████| 213/213 [00:02<00:00, 81.69it/s]
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 428.49it/s]
---------------------------------------------
PSPA: PSSM + multiply + pct
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|███████████████████████████████████████████████████████████████████████████| 213/213 [00:02<00:00, 85.26it/s]
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 423.93it/s]
100%|██████████████████████████████████████████████████████████████████████████| 213/213 [00:00<00:00, 389.86it/s]
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 764.05it/s]
---------------------------------------------
CDDM: PSSM + multiply
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|███████████████████████████████████████████████████████████████████████████| 213/213 [00:04<00:00, 50.58it/s]
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 215.43it/s]
---------------------------------------------
CDDM: PSSM + multiply + pct
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|███████████████████████████████████████████████████████████████████████████| 213/213 [00:04<00:00, 50.47it/s]
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 177.45it/s]
100%|██████████████████████████████████████████████████████████████████████████| 213/213 [00:00<00:00, 373.54it/s]
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 509.36it/s]
---------------------------------------------
CDDM: PSSM upper + multiply
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|███████████████████████████████████████████████████████████████████████████| 213/213 [00:04<00:00, 51.73it/s]
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 228.69it/s]
---------------------------------------------
CDDM: PSSM upper + multiply + pct
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|███████████████████████████████████████████████████████████████████████████| 213/213 [00:04<00:00, 52.55it/s]
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 225.45it/s]
100%|██████████████████████████████████████████████████████████████████████████| 213/213 [00:00<00:00, 388.72it/s]
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 701.91it/s]
---------------------------------------------
CDDM: LO + sum
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
---------------------------------------------
CDDM: LO + sum + pct
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|██████████████████████████████████████████████████████████████████████████| 213/213 [00:00<00:00, 382.47it/s]
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 749.55it/s]
---------------------------------------------
CDDM: LO upper + sum
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
---------------------------------------------
CDDM: LO upper + sum + pct
input dataframe has a length 10603
Preprocessing
Finish preprocessing
Merging reference
Finish merging
input dataframe has a length 1912
Preprocessing
Finish preprocessing
Merging reference
Finish merging
100%|██████████████████████████████████████████████████████████████████████████| 213/213 [00:00<00:00, 378.69it/s]
100%|████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 764.31it/s]
= aucdf.T aucdf
= aucdf.reset_index(names='group') aucdf
'group').T.sort_values('Tyr') aucdf.set_index(
group | ST | Tyr |
---|---|---|
PSPA: PSSM + multiply + pct | 0.788949 | 0.572925 |
PSPA: PSSM + multiply | 0.803919 | 0.599230 |
CDDM: PSSM upper + multiply + pct | 0.845689 | 0.720623 |
CDDM: PSSM + multiply + pct | 0.853593 | 0.767071 |
CDDM: LO upper + sum + pct | 0.890996 | 0.776887 |
CDDM: PSSM upper + multiply | 0.915436 | 0.797867 |
CDDM: LO upper + sum | 0.915436 | 0.797867 |
CDDM: LO + sum + pct | 0.896505 | 0.802644 |
CDDM: LO + sum | 0.927760 | 0.821030 |
CDDM: PSSM + multiply | 0.927760 | 0.821030 |
=aucdf.columns[1:],group='group',rotation=0)
plot_group_bar(aucdf,value_cols'AUCDF') plt.ylabel(
Text(0, 0.5, 'AUCDF')
= pd.concat(topk_dfs,axis=1) topk_df
= ref_dict.keys() topk_df.columns
= topk_df.mean(1).sort_values(ascending=False).index idx_ordered
= topk_df.loc[idx_ordered] topk_df
= topk_df.reset_index(names='group') topk_df
=topk_df.columns[1:],group='group')
plot_group_bar(topk_df,value_cols'Top10 accuracy') plt.ylabel(
Text(0, 0.5, 'Top10 accuracy')
= topk_df.columns[~topk_df.columns.str.contains('pct')][1:] col
=col,group='group')
plot_group_bar(topk_df,value_cols'Top10 accuracy')
plt.ylabel('img.svg') save_svg(