from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from tqdm import tqdm
from katlas.core import *
import pandas as pd, numpy as np,seaborn as sns
Kmeans motifs
Kmeans
Data
# human = pd.read_parquet('raw/human_phosphoproteome.parquet')
# df_grouped = pd.read_parquet('raw/combine_source_grouped.parquet')
# Load the two site tables through the katlas `Data` helper and stack them row-wise.
human = Data.get_human_site()
df_grouped = Data.get_ks_dataset()
all_site = pd.concat([human, df_grouped])
# Sanity check: number of rows without a `sub_site` identifier.
all_site.sub_site.isna().sum()
np.int64(0)
# Keep one row per `sub_site`.
all_site = all_site.drop_duplicates('sub_site')
all_site.shape
(131843, 21)
# all_site = all_site[['sub_site','site_seq']].drop_duplicates('sub_site')
One-hot encode
# One-hot encode the site sequences; this matrix is the k-means input below.
onehot = onehot_encode(all_site['site_seq'])
CPU times: user 2.21 s, sys: 1.14 s, total: 3.35 s
Wall time: 3.35 s
onehot.head()
-20A | -20C | -20D | -20E | -20F | -20G | -20H | -20I | -20K | -20L | ... | 20R | 20S | 20T | 20V | 20W | 20Y | 20_ | 20s | 20t | 20y | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
1 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
3 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
5 rows × 967 columns
Elbow method
all_site.shape
(131843, 19)
# Figure settings: high-resolution output, notebook scaling, "ticks" style.
sns.set(rc={"figure.dpi":300, 'savefig.dpi':300})
sns.set_context('notebook')
sns.set_style("ticks")
get_clusters_elbow(onehot)
CPU times: user 23min 20s, sys: 46.3 s, total: 24min 6s
Wall time: 9min 50s
Kmeans
If using RAPIDS
# # pip install --extra-index-url=https://pypi.nvidia.com \"cudf-cu12==25.2.*\" \"cuml-cu12==25.2.*\"
# %load_ext cudf.pandas
# import numpy as np, pandas as pd
# from cuml import KMeans
# import matplotlib.pyplot as plt
# import seaborn as sns
# from tqdm import tqdm
# from katlas.core import *
# from katlas.plot import *
def kmeans(onehot,n=2,seed=42):
    """Fit k-means on `onehot` and return the cluster label of every row.

    onehot: feature matrix passed to `KMeans.fit_predict` (one row per site).
    n:      number of clusters.
    seed:   random state, so that a run can be reproduced.
    """
    # Local is named `model` so it does not shadow this function's own name.
    model = KMeans(n_clusters=n, random_state=seed, n_init='auto')
    return model.fit_predict(onehot)
# Grid of k-means runs: every cluster count is combined with every random seed.
ncluster=[50,150,300]
seeds=[42,2025,28]
# Smoke test: put every site in one dummy cluster and build its PSSM.
all_site['test_id']=1
get_cluster_pssms(all_site,'test_id')
100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00, 2.03s/it]
-20P | -20G | -20A | -20C | -20S | -20T | -20V | -20I | -20L | -20M | ... | 20H | 20K | 20R | 20Q | 20N | 20D | 20E | 20pS | 20pT | 20pY | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 0.07984 | 0.06669 | 0.07291 | 0.01333 | 0.06488 | 0.04065 | 0.0505 | 0.0346 | 0.0812 | 0.01958 | ... | 0.02192 | 0.06782 | 0.06339 | 0.04731 | 0.0334 | 0.05239 | 0.0798 | 0.04184 | 0.01385 | 0.00508 |
1 rows × 943 columns
# Run k-means for every (seed, n_clusters) pair and collect one PSSM table per run.
pssms=[]
for seed in seeds:
    print('seed',seed)
    for n in ncluster:
        colname = f'cluster{n}_seed{seed}'
        print(colname)
        # Store each site's cluster label in a column named after the run.
        all_site[colname] = kmeans(onehot,n=n,seed=seed)
        pssm_df = get_cluster_pssms(all_site,colname) # count threshold 10, no threshold for non-nan
        # Prefix the cluster index with the run name so indices stay unique once the tables are concatenated.
        pssm_df.index = colname+'_'+pssm_df.index.astype(str)
        pssms.append(pssm_df)
seed 42
cluster50_seed42
100%|████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 17.46it/s]
cluster150_seed42
100%|██████████████████████████████████████████████████████████████████████████████| 150/150 [00:04<00:00, 34.35it/s]
cluster300_seed42
100%|██████████████████████████████████████████████████████████████████████████████| 300/300 [00:06<00:00, 45.00it/s]
seed 2025
cluster50_seed2025
100%|████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 17.37it/s]
cluster150_seed2025
100%|██████████████████████████████████████████████████████████████████████████████| 150/150 [00:04<00:00, 34.96it/s]
cluster300_seed2025
100%|██████████████████████████████████████████████████████████████████████████████| 300/300 [00:07<00:00, 42.13it/s]
seed 28
cluster50_seed28
100%|████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 17.92it/s]
cluster150_seed28
100%|██████████████████████████████████████████████████████████████████████████████| 150/150 [00:04<00:00, 33.37it/s]
cluster300_seed28
100%|██████████████████████████████████████████████████████████████████████████████| 300/300 [00:07<00:00, 41.01it/s]
# Stack the per-run PSSM tables into one frame (one row per run-specific cluster).
pssms = pd.concat(pssms,axis=0)
Save:
# pssms.to_parquet('raw/kmeans.parquet')
# all_site.to_parquet('raw/kmeans_site.parquet',index=False)
# pssms = pd.read_parquet('raw/kmeans.parquet')
# all_site=pd.read_parquet('raw/kmeans_site.parquet')
Hierarchical clustering
Hierarchical clustering of all pssms
from scipy.cluster.hierarchy import linkage, fcluster,dendrogram
import pandas as pd
from katlas.core import *
pssms = pd.read_parquet('raw/kmeans.parquet')
labels = get_pssm_seq_labels(pssms)
Z = get_Z(pssms)
# Draw the dendrogram, write it to PDF, then close the figure.
plot_dendrogram(Z,labels=labels)
save_pdf('dendrogram.pdf')
plt.close()
Visualize the dendrogram and find a color threshold that works.
After determining the color threshold, use it to cut the tree.
Visualize some logos
plot_logos(pssms,'cluster300_seed28_217','cluster50_seed42_40')
Cut trees to merge similar pssms
# Cut the tree at distance 0.03: PSSMs that fall in the same flat cluster are merged.
labels = fcluster(Z, t=0.03, criterion='distance')
# pssm_df['cluster'] = labels
len(labels)
1500
np.unique(labels)[:10] # always start from 1
array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int32)
Expand the cluster into a single column
# Reshape to long format: one row per (site, k-means run).
id_vars = ['sub_site', 'site_seq']
value_vars = [col for col in all_site.columns if col.startswith('cluster')]
all_site_long = pd.melt(all_site, id_vars=id_vars, value_vars=value_vars, var_name='cluster_info', value_name='cluster')
# Same "<run>_<cluster>" key that is used as the index of `pssms`.
all_site_long['cluster_id']=all_site_long['cluster_info'] + '_' + all_site_long['cluster'].astype(str)
all_site_long.head()
sub_site | site_seq | cluster_info | cluster | cluster_id | |
---|---|---|---|---|---|
0 | A0A024R4G9_S20 | _MTVLEAVLEIQAITGSRLLsMVPGPARPPGSCWDPTQCTR | cluster50_seed42 | 33 | cluster50_seed42_33 |
1 | A0A075B6Q4_S24 | QKSENEDDSEWEDVDDEKGDsNDDYDSAGLLsDEDCMSVPG | cluster50_seed42 | 41 | cluster50_seed42_41 |
2 | A0A075B6Q4_S35 | EDVDDEKGDsNDDYDSAGLLsDEDCMSVPGKTHRAIADHLF | cluster50_seed42 | 30 | cluster50_seed42_30 |
3 | A0A075B6Q4_S57 | EDCMSVPGKTHRAIADHLFWsEETKSRFTEYsMTssVMRRN | cluster50_seed42 | 30 | cluster50_seed42_30 |
4 | A0A075B6Q4_S68 | RAIADHLFWsEETKSRFTEYsMTssVMRRNEQLTLHDERFE | cluster50_seed42 | 11 | cluster50_seed42_11 |
# all_site_long.to_parquet('raw/kmeans_site_long.parquet',index=False)
Map merged cluster
len(labels)
1500
# Lookup from run-specific cluster ID (the index of `pssms`) to merged cluster label.
cluster_map = pd.Series(labels,index=pssms.index)
cluster_map.sort_values()
cluster50_seed42_28 1
cluster300_seed42_272 1
cluster150_seed2025_137 1
cluster150_seed28_104 1
cluster150_seed42_147 1
...
cluster150_seed2025_134 783
cluster50_seed2025_8 783
cluster150_seed42_101 783
cluster50_seed28_48 783
cluster150_seed28_131 783
Length: 1500, dtype: int32
Cluster IDs that have no mapping are assigned the value 0:
# Not every cluster_id has a corresponding new cluster ID, as some could have been filtered out.
all_site_long['cluster_new'] = all_site_long.cluster_id.map(lambda x: cluster_map.get(x, 0)) #0 is unmapped
# all_site_long.to_parquet('raw/kmeans_site_long_cluster_new.parquet',index=False)
Get new cluster motifs
# Rebuild the PSSMs, this time grouping sites by merged cluster.
pssms2 = get_cluster_pssms(all_site_long,'cluster_new')
100%|██████████████████████████████████████████████████████████████████████████████| 783/783 [00:22<00:00, 35.32it/s]
pssms2.shape
(783, 943)
# pssms2 = pssms2.drop(index=0) # as 0 represents unmapped
# pssms2.sort_index().to_parquet('out/all_site_pssms.parquet')
pssms2 = pd.read_parquet('out/all_site_pssms.parquet')
Hierarchical clustering of merged pssms
Z = get_Z(pssms2)
all_site_long.cluster_new
0 391
1 548
2 495
3 495
4 594
...
1186582 689
1186583 348
1186584 672
1186585 672
1186586 701
Name: cluster_new, Length: 1186587, dtype: int32
# Number of distinct sites per merged cluster; a site is counted once even if several runs put it there.
count_map = all_site_long.drop_duplicates(subset=["cluster_new","sub_site"])["cluster_new"].value_counts()
count_map
cluster_new
741 12595
473 10903
701 10566
679 8251
317 7594
...
35 21
10 21
754 20
4 17
7 16
Name: count, Length: 783, dtype: int64
labels = get_pssm_seq_labels(pssms2)
labels[:4]
['741 (n=12,595): ....................t*....................',
'473 (n=10,903): ....................s*....................',
'701 (n=10,566): ....................y*....................',
'679 (n=8,251): ....................t*s[s/P]..................']
# Draw the dendrogram of the merged PSSMs, write it to PDF, then close the figure.
plot_dendrogram(Z,interval=7,labels=labels,color_thr=0.07)
save_pdf('raw/dendrogram.pdf')
plt.close()
Onehot of cluster number
# NOTE(review): the commented-out save of 'raw/kmeans_site_long.parquet' appears before `cluster_new`
# is added, and a separate 'raw/kmeans_site_long_cluster_new.parquet' save exists; confirm that the
# file read here actually contains the `cluster_new` column used below.
all_site_long = pd.read_parquet('raw/kmeans_site_long.parquet')
# Row key that combines the site identifier with its sequence.
all_site_long['sub_site_seq'] = all_site_long['sub_site']+'_'+all_site_long['site_seq']
# Rows are sites, columns are merged clusters, values count the runs that assigned the site there.
all_site_onehot = pd.crosstab(all_site_long['sub_site_seq'], all_site_long['cluster_new'])
# greater than 0 to be True and convert to int
all_site_onehot = all_site_onehot.gt(0).astype(int)
all_site_onehot.max()
cluster_new
1 1
2 1
3 1
4 1
5 1
..
779 1
780 1
781 1
782 1
783 1
Length: 783, dtype: int64
# remove 0 as it is unassigned for cut tree
# all_site_onehot = all_site_onehot.drop(columns=0)
all_site_onehot.columns
Index([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
...
774, 775, 776, 777, 778, 779, 780, 781, 782, 783],
dtype='int32', name='cluster_new', length=783)
# Number of sites per merged cluster, smallest first.
all_site_onehot.sum().sort_values()
cluster_new
7 16
4 17
754 20
10 21
35 21
...
317 7594
679 8251
701 10566
473 10903
741 12595
Length: 783, dtype: int64
# Column names must be str before the frame can be saved to parquet.
all_site_onehot.columns = all_site_onehot.columns.astype(str)
# all_site_onehot.to_parquet('out/all_site_cluster_onehot.parquet')
# all_site_onehot=pd.read_parquet('out/all_site_cluster_onehot.parquet')
all_site_onehot
cluster_new | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ... | 774 | 775 | 776 | 777 | 778 | 779 | 780 | 781 | 782 | 783 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
sub_site_seq | |||||||||||||||||||||
A0A024R4G9_S20__MTVLEAVLEIQAITGSRLLsMVPGPARPPGSCWDPTQCTR | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
A0A075B6Q4_S24_QKSENEDDSEWEDVDDEKGDsNDDYDSAGLLsDEDCMSVPG | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
A0A075B6Q4_S35_EDVDDEKGDsNDDYDSAGLLsDEDCMSVPGKTHRAIADHLF | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
A0A075B6Q4_S57_EDCMSVPGKTHRAIADHLFWsEETKSRFTEYsMTssVMRRN | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
A0A075B6Q4_S68_RAIADHLFWsEETKSRFTEYsMTssVMRRNEQLTLHDERFE | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
V9GYY5_S132_RLLGLtPPEGGAGDRsEEEAsstEKPtKALPRKSRDPLLSQ | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
V9GYY5_S133_LLGLtPPEGGAGDRsEEEAsstEKPtKALPRKSRDPLLSQR | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
V9GYY5_T117_TVTVTTISDLDLsGARLLGLtPPEGGAGDRsEEEAsstEKP | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
V9GYY5_T134_LGLtPPEGGAGDRsEEEAsstEKPtKALPRKSRDPLLSQRI | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
V9GYY5_T138_PPEGGAGDRsEEEAsstEKPtKALPRKSRDPLLSQRISSLT | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
131843 rows × 783 columns
all_site_onehot.head()
cluster_new | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ... | 774 | 775 | 776 | 777 | 778 | 779 | 780 | 781 | 782 | 783 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
sub_site_seq | |||||||||||||||||||||
A0A024R4G9_S20__MTVLEAVLEIQAITGSRLLsMVPGPARPPGSCWDPTQCTR | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
A0A075B6Q4_S24_QKSENEDDSEWEDVDDEKGDsNDDYDSAGLLsDEDCMSVPG | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
A0A075B6Q4_S35_EDVDDEKGDsNDDYDSAGLLsDEDCMSVPGKTHRAIADHLF | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
A0A075B6Q4_S57_EDCMSVPGKTHRAIADHLFWsEETKSRFTEYsMTssVMRRN | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
A0A075B6Q4_S68_RAIADHLFWsEETKSRFTEYsMTssVMRRNEQLTLHDERFE | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 783 columns
# Switch the column names back to int for the downstream lookups by cluster number.
all_site_onehot.columns = all_site_onehot.columns.astype(int)
Special motifs
def plot_logos(pssms_df,*idxs):
    """Plot one sequence logo per requested row of a dataframe of flattened PSSMs.

    pssms_df: dataframe whose rows are flattened PSSMs and whose index holds their IDs.
    *idxs:    index labels of the rows to plot; each one gets its own figure.
    """
    for idx in idxs:
        pssm = recover_pssm(pssms_df.loc[idx])
        plot_logo(pssm,title=f'Motif {idx}',figsize=(14,1))
        plt.show()
        # Close each figure after showing it so open figures do not accumulate across the loop.
        plt.close()
Customized
plot_logos(pssms2,748,669,668)
Most common
# The 20 merged clusters that contain the most sites.
idxs = all_site_onehot.sum().sort_values(ascending=False).head(20).index
plot_logos(pssms2,*idxs)
Entropy per position
# One entropy value per position for every merged PSSM.
pssms2_entropy = pssms2.apply(lambda r: entropy_flat(r),axis=1)
# remove 0 and focus on those neighboring residues
pssms2_entropy = pssms2_entropy.drop(columns=0)
# The 10 motifs whose lowest-entropy position is the smallest.
idxs = pssms2_entropy.min(axis=1).sort_values().head(10).index
Motifs with the lowest entropies:
plot_logos(pssms2,*idxs)
Low sum entropy (similar to median, but remove terminal)
Motifs with low sum entropies:
# The 80 motifs with the smallest entropy summed over positions.
idxs = pssms2_entropy.sum(axis=1).sort_values().head(80).index
The first one consists mostly of zinc finger proteins.
plot_logos(pssms2,*idxs)
# Alternative rankings: the 10 lowest summed entropies, then motifs ordered by number of zero entries.
idxs = pssms2_entropy.sum(axis=1).sort_values().head(10).index
idxs = (pssms2==0).sum(axis=1).sort_values(ascending=False).index
C-terminal motifs
# For each motif, count the positions from +1 onward whose column sums to zero.
zeros_right = pssms2.apply(lambda r: (recover_pssm(r).loc[:,1:].sum()==0).sum() , axis=1)
idxs = zeros_right.sort_values(ascending=False).head(10).index
plot_logos(pssms2,*idxs)
N-Terminal motifs:
# For each motif, count the positions up to and including 0 whose column sums to zero.
zeros_left = pssms2.apply(lambda r: (recover_pssm(r).loc[:,:0].sum()==0).sum() , axis=1)
idxs = zeros_left.sort_values(ascending=False).head(10).index
# zeros = pssms2.apply(lambda r: (recover_pssm(r).sum()==0).sum() , axis=1)
# idxs = zeros.sort_values(ascending=False).head(10).index
plot_logos(pssms2,*idxs)