MPNN (Message Passing Neural Network) study

Setup

!pip install torch_geometric
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0/63.1 kB ? eta -:--:--     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.1/63.1 kB 3.7 MB/s eta 0:00:00
Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (3.10.10)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (2024.6.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (3.1.4)
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (1.26.4)
Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (5.9.5)
Requirement already satisfied: pyparsing in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (3.2.0)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (2.32.3)
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (4.66.5)
Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (2.4.3)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (1.3.1)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (24.2.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (1.4.1)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (6.1.0)
Requirement already satisfied: yarl<2.0,>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (1.16.0)
Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (4.0.3)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch_geometric) (3.0.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torch_geometric) (3.4.0)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torch_geometric) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torch_geometric) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torch_geometric) (2024.8.30)
Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.10/dist-packages (from multidict<7.0,>=4.5->aiohttp->torch_geometric) (4.12.2)
Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from yarl<2.0,>=1.12.0->aiohttp->torch_geometric) (0.2.0)
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 29.3 MB/s eta 0:00:00
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
import torch
torch.__version__
'2.5.0+cu121'

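The installed torch is 2.5.0, but the newest torch-scatter wheels currently available are built for torch 2.4.0, so the wheel index below is pinned to 2.4.0.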

version = '2.4.0+cu121'
url = f"https://data.pyg.org/whl/torch-{version}.html"
!pip install torch-scatter -f {url}
Looking in links: https://data.pyg.org/whl/torch-2.4.0+cu121.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_scatter-2.1.2%2Bpt24cu121-cp310-cp310-linux_x86_64.whl (10.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.9/10.9 MB 85.1 MB/s eta 0:00:00
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt24cu121

Mechanism

# Toy directed graph stored as an edge list with 2 rows and 4 columns (one
# column per edge). message_func below reads row 0 as the source node of each
# edge and row 1 as the target node.
g = torch.tensor([[0,0,2,3],[1,3,1,1]],dtype=torch.long)
# Random node features: 4 nodes, 3 values per node.
x = torch.randn(4,3)
# Presumably edge features (4 rows to match the 4 edges, 2 values each);
# nothing in this cell's functions reads them yet.
x_e = torch.randn(4,2)

def message_func(g,x,x_e):
  """Build one message per edge by copying the source node's features.

  Args:
    g: edge index tensor with two rows; ``g[0]`` holds the source node id of
      every edge and ``g[1]`` the target node id.
    x: node feature tensor indexed by node id along dim 0.
    x_e: accepted so the signature stays uniform, but not used here.

  Returns:
    Tensor with one row per edge; row k equals ``x[g[0][k]]``.
  """
  # The target-side features would be x[g[1]]; this message does not depend
  # on them, so they are not gathered.
  return x[g[0]]

message = message_func(g,x,x_e)
g = torch.tensor([[0,0,2,3],[1,3,1,1]],dtype=torch.long)
x = torch.randn(4,3)
x_e = torch.randn(4,2)
def message_func(g,x,x_e):
  """Return, for every edge, the feature row of that edge's source node.

  Args:
    g: two-row edge index tensor (row 0 = source ids, row 1 = target ids).
    x: node features, one row per node.
    x_e: unused; kept only so callers can pass it.

  Returns:
    ``x[g[0]]``: a tensor whose k-th row is the features of edge k's source.
  """
  src = x[g[0]]
  # No gather of x[g[1]] (target features): the message is the source row only.
  return src
message = message_func(g,x,x_e)
message
tensor([[ 0.2792, -0.3888,  0.9433],
        [ 0.2792, -0.3888,  0.9433],
        [ 1.4204,  0.6912,  0.3983],
        [ 1.2190,  0.1917, -1.3613]])
x
tensor([[ 0.6476, -0.3849,  0.4243],
        [-1.2170,  2.3134, -0.0156],
        [-0.6478, -1.8168, -0.3283],
        [ 1.6240, -0.1349, -0.9557]])
def update_func(x,x_reduce):
  """Combine each node's current features with its aggregated messages.

  Args:
    x: current node features.
    x_reduce: aggregated incoming messages, added to ``x`` with ``+``.

  Returns:
    The sum ``x + x_reduce``.
  """
  updated = x + x_reduce
  return updated
new_x = update_func(x,x_reduce)
x = torch.tensor([[1],[0.1],[0.01],[0.001]])
x
tensor([[1.0000],
        [0.1000],
        [0.0100],
        [0.0010]])
message = message_func(g,x,x_e)
message
tensor([[1.0000],
        [1.0000],
        [0.0100],
        [0.0010]])
from torch_scatter import scatter
def reduce_func(g,message,num_nodes=None):
  """Sum the per-edge messages into the node each edge points to.

  Args:
    g: two-row edge index tensor; ``g[1]`` holds the target node id of each
      edge and is used as the scatter index.
    message: per-edge messages, one row per edge along dim 0.
    num_nodes: number of rows the result should have. With the default
      ``None`` the output is sized from the largest target id in ``g[1]``,
      so trailing nodes without incoming edges get no row. Pass the real
      node count to get one row per node.

  Returns:
    Tensor whose row i is the sum of all messages sent to node i (zeros for
    nodes that receive nothing).
  """
  return scatter(message,g[1],dim=0,dim_size=num_nodes,reduce='sum')
x_reduce = reduce_func(g,message)
message
tensor([[ 0.2792, -0.3888,  0.9433],
        [ 0.2792, -0.3888,  0.9433],
        [ 1.4204,  0.6912,  0.3983],
        [ 1.2190,  0.1917, -1.3613]])
g[1]
tensor([1, 3, 1, 1])
x_reduce
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 2.9186,  0.4941, -0.0197],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.2792, -0.3888,  0.9433]])
g.shape
torch.Size([2, 4])
def get_in_degrees(g,num_nodes=None):
  """Count the incoming edges of every node.

  Args:
    g: two-row edge index tensor; ``g[1]`` holds each edge's target node id.
    num_nodes: length of the returned vector. With the default ``None`` the
      length is derived from the largest id in ``g[1]``, which is too short
      when the highest-numbered nodes have no incoming edge. Pass the node
      count when the result will be indexed by arbitrary node ids.

  Returns:
    1-D float tensor; entry i is the number of edges whose target is node i.
  """
  # One "1" per edge, created on the same device as the edge index so the
  # scatter does not mix devices.
  ones = torch.ones(g.size(1),device=g.device)
  return scatter(ones,g[1],dim=0,dim_size=num_nodes,reduce='sum')
degrees = get_in_degrees(g)
degrees
tensor([0., 3., 0., 1.])
from torch_geometric.nn.conv import MessagePassing
class GCNLayer(MessagePassing):
  """Hand-written GCN-style layer that prints each intermediate step.

  The layer linearly transforms the node features, scales every edge's
  message by ``deg(src)^-0.5 * deg(dst)^-0.5`` (in-degrees), and sums the
  scaled messages at each target node.

  Args:
    in_channels: size of each input node feature vector (default 2).
    out_channels: size of each output node feature vector (default 3).
  """

  def __init__(self, in_channels=2, out_channels=3):
    super().__init__(aggr='add')  # messages arriving at a node are summed
    self.lin = torch.nn.Linear(in_channels, out_channels)

  def forward(self,g,x): # note the order: edge index first, node features second
    """Run one round of normalised message passing.

    Args:
      g: two-row edge index tensor (row 0 = source ids, row 1 = target ids).
      x: node features, one row per node.

    Returns:
      Updated node features, one row per node.
    """
    print('input node feature',x)

    x = self.lin(x)
    print('step1, linear trasnform the node features:',x)

    degrees = get_in_degrees(g)
    # The degree vector can be shorter than the node count when the
    # highest-numbered nodes have no incoming edge; pad with zeros so that
    # indexing by source id below cannot run out of range.
    missing = x.size(0) - degrees.numel()
    if missing > 0:
      degrees = torch.cat([degrees, degrees.new_zeros(missing)])
    # d^-0.5 per node. Nodes with degree 0 would give inf, so they are mapped
    # to 0, which zeroes out any message they send.
    degrees = torch.where(degrees > 0, degrees.pow(-0.5), torch.zeros_like(degrees))

    src_d = degrees[g[0]]
    dst_d = degrees[g[1]]
    print('src_d:',src_d,'dst_d:',dst_d)
    # One scalar weight per edge, shaped [E, 1] so it broadcasts over features.
    weight = (src_d*dst_d).unsqueeze(1)
    print('step2 normalized degree:',weight)

    # propagate() calls message() with the per-edge source features and the
    # weight, then sums the returned messages per target node (aggr='add').
    out = self.propagate(g,x=x, weight = weight)
    print('step5', out)

    return out

  def message(self,x_j, weight): # source j, target i
    """Scale each edge's source features by that edge's weight.

    Args:
      x_j: features of the source node of every edge, one row per edge.
      weight: per-edge normalisation factor computed in ``forward``.

    Returns:
      The weighted source features, one row per edge.
    """
    print('step3 get source features:',x_j,'normalized degree:',weight)
    out = x_j*weight
    print('step4 normalized source features:',out)
    return out
layer = GCNLayer()
x = torch.randn(4,2)
x
tensor([[-2.1208, -0.5321],
        [-0.4745,  0.9924],
        [ 0.1378,  0.6422],
        [-0.6585, -1.2959]])
g = torch.tensor([[0,0,2,3,1],[1,3,1,1,0]],dtype=torch.long)
new_x = layer(g,x)
input node feature tensor([[-2.1208, -0.5321],
        [-0.4745,  0.9924],
        [ 0.1378,  0.6422],
        [-0.6585, -1.2959]])
step1, linear trasnform the node features: tensor([[ 0.1784, -0.3058,  1.1049],
        [-1.1339,  0.2926, -0.6656],
        [-0.9919, -0.0605, -0.6803],
        [ 0.4678, -1.1036,  1.0191]], grad_fn=<AddmmBackward0>)
src_d: tensor([1.0000, 1.0000, 0.0000, 1.0000, 1.7321]) dst_d: tensor([1.7321, 1.0000, 1.7321, 1.7321, 1.0000])
step2 normalized degree: tensor([[1.7321],
        [1.0000],
        [0.0000],
        [1.7321],
        [1.7321]])
step3 get source features: tensor([[ 0.1784, -0.3058,  1.1049],
        [ 0.1784, -0.3058,  1.1049],
        [-0.9919, -0.0605, -0.6803],
        [ 0.4678, -1.1036,  1.0191],
        [-1.1339,  0.2926, -0.6656]], grad_fn=<IndexSelectBackward0>) normalized degree: tensor([[1.7321],
        [1.0000],
        [0.0000],
        [1.7321],
        [1.7321]])
step4 normalized source features: tensor([[ 0.3089, -0.5296,  1.9138],
        [ 0.1784, -0.3058,  1.1049],
        [-0.0000, -0.0000, -0.0000],
        [ 0.8103, -1.9114,  1.7651],
        [-1.9640,  0.5068, -1.1529]], grad_fn=<MulBackward0>)
step5 tensor([[-1.9640,  0.5068, -1.1529],
        [ 1.1192, -2.4411,  3.6790],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.1784, -0.3058,  1.1049]], grad_fn=<ScatterAddBackward0>)
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
bias = Parameter(torch.empty(10))
torch.empty?
bias
Parameter containing:
tensor([3.0695e+37, 4.3868e-41, 4.6949e+20, 3.2917e-41, 1.0282e-14, 4.3868e-41,
        5.6391e+20, 3.2917e-41, 0.0000e+00, 0.0000e+00], requires_grad=True)
class GCNConv(MessagePassing):
    """GCN convolution with self-loops and symmetric degree normalisation.

    Every stage prints its intermediate tensors so the data flow can be
    followed step by step.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        # Incoming messages are summed at the target node.
        super().__init__(aggr='add')
        # The linear map has no bias of its own; a separate bias is added
        # after propagation in forward().
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear weights and reset the bias to zero."""
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        """Apply the convolution.

        Args:
            x: node features of shape [N, in_channels].
            edge_index: edge list of shape [2, E].

        Returns:
            Node features after propagation plus bias.
        """
        num_nodes = x.size(0)

        print('step1',edge_index)
        # Add one self-loop per node to the edge list.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        print('step2',edge_index)
        print('x',x)

        # Linear transform of the node features.
        x = self.lin(x)

        # Per-edge normalisation factor deg(row)^-0.5 * deg(col)^-0.5, where
        # deg counts occurrences in the second row of the edge list.
        row = edge_index[0]
        col = edge_index[1]
        print('row',row)
        print('col',col)
        deg = degree(col, num_nodes, dtype=x.dtype)
        print('deg',deg)

        deg_inv_sqrt = deg.pow(-0.5)
        # A zero degree yields inf; replace it with 0.
        is_inf = deg_inv_sqrt == float('inf')
        deg_inv_sqrt[is_inf] = 0
        edge_norm = (deg_inv_sqrt[row] * deg_inv_sqrt[col]).unsqueeze(1)
        print('step3 normalize degrees',edge_norm)

        # propagate() first hands these values to message(), then merges the
        # returned per-edge messages with a scatter.
        out = self.propagate(edge_index, x=x, norm=edge_norm)
        print(' step5 propage:', out)

        # Final bias.
        return out + self.bias

    def message(self, x_j, norm):
        """Scale the source-node features of each edge by its norm factor.

        Message passing separates out the source (j) and target (i) sides
        and also supplies the ``norm`` values passed to ``propagate``.

        Args:
            x_j: source-node features per edge, shape [E, out_channels].
            norm: per-edge normalisation factor.

        Returns:
            ``norm * x_j``.
        """
        print('step4 normalize source feature',x_j,'norm:',norm)
        return norm * x_j
edge_index = g
x = torch.randn(4,2)
conv = GCNConv(2, 3)
out = conv(x, edge_index)
step1 tensor([[0, 0, 2, 3, 1],
        [1, 3, 1, 1, 0]])
step2 tensor([[0, 0, 2, 3, 1, 0, 1, 2, 3],
        [1, 3, 1, 1, 0, 0, 1, 2, 3]])
x tensor([[-2.1208, -0.5321],
        [-0.4745,  0.9924],
        [ 0.1378,  0.6422],
        [-0.6585, -1.2959]])
row tensor([0, 0, 2, 3, 1, 0, 1, 2, 3])
col tensor([1, 3, 1, 1, 0, 0, 1, 2, 3])
deg tensor([2., 4., 1., 2.])
step3 normalize degrees tensor([[0.3536],
        [0.5000],
        [0.5000],
        [0.3536],
        [0.3536],
        [0.5000],
        [0.2500],
        [1.0000],
        [0.5000]])
step4 normalize source feature tensor([[-0.9057, -1.6206,  0.2800],
        [-0.9057, -1.6206,  0.2800],
        [ 0.2167,  0.3905,  0.3159],
        [-0.5749, -1.0340, -0.5348],
        [ 0.0860,  0.1591,  0.6738],
        [-0.9057, -1.6206,  0.2800],
        [ 0.0860,  0.1591,  0.6738],
        [ 0.2167,  0.3905,  0.3159],
        [-0.5749, -1.0340, -0.5348]], grad_fn=<IndexSelectBackward0>) norm: tensor([[0.3536],
        [0.5000],
        [0.5000],
        [0.3536],
        [0.3536],
        [0.5000],
        [0.2500],
        [1.0000],
        [0.5000]])
 step5 propage: tensor([[-0.4224, -0.7541,  0.3782],
        [-0.3936, -0.7035,  0.2363],
        [ 0.2167,  0.3905,  0.3159],
        [-0.7403, -1.3273, -0.1274]], grad_fn=<ScatterAddBackward0>)