{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 3: SpaCon for mouse spatial transcriptomics and widefield functional connectivity data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial demonstrates how to use SpaCon to integrate mouse gene expression and wide‑field calcium imaging data. We first mapped the mouse spatial transcriptomics data onto a cortical atlas and aligned it with the wide‑field calcium imaging data.\n", "\n", "The spatial transcriptomics data used here is the MERFISH dataset from the study published in *Nature*: [Molecularly defined and spatially resolved cell atlas of the whole mouse brain](https://www.nature.com/articles/s41586-023-06808-9). The wide‑field calcium imaging data come from: [Diverse and asymmetric patterns of single‑neuron projectome in regulating interhemispheric connectivity](https://www.nature.com/articles/s41467-024-47762-y).\n", "\n", "We co‑registered both datasets to the same spatial resolution. 
The processed data for this tutorial can be downloaded from this [Google Drive link](https://drive.google.com/drive/folders/1lQQQVjXt8lvciuKq_Jkp4LGsR_EQXFAi?usp=sharing).\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import spacon\n", "from spacon.utils import build_spatial_graph, build_connection_graph, neighbor_sample, model_train, model_eval, clustering\n", "\n", "from spacon.model import SpaCon\n", "\n", "import datetime\n", "import os\n", "import scanpy as sc\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import numpy as np\n", "import random\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# plotting axes and figure size depend on the section orientation of the selected mouse\n", "mus = 'mouse_3'\n", "if mus == 'mouse_1':  # coronal\n", "    plot_x, plot_y = 'z', 'y'\n", "    figsize = (5, 5)\n", "elif mus == 'mouse_3':  # sagittal\n", "    plot_x, plot_y = 'x', 'y'\n", "    figsize = (11, 5)\n", "\n", "\n", "def set_seed(seed: int):\n", "    # fix the Python, NumPy and PyTorch (CPU and CUDA) random seeds for reproducibility\n", "    os.environ['PYTHONHASHSEED'] = str(seed)\n", "    random.seed(seed)\n", "    np.random.seed(seed)\n", "    torch.manual_seed(seed)\n", "    torch.cuda.manual_seed(seed)\n", "    torch.cuda.manual_seed_all(seed)\n", "    torch.backends.cudnn.deterministic = True\n", "    torch.backends.cudnn.benchmark = False\n", "\n", "set_seed(42)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "**Data preprocessing**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Load spatial transcriptomics data**\n", "\n", "If the dataset contains many genes (for example, more than 5,000), we recommend first selecting highly variable genes:\n", "\n", "```python\n", "n_top_genes = 3000\n", "sc.pp.highly_variable_genes(adata, flavor=\"seurat\", n_top_genes=n_top_genes)\n", "adata = adata[:, adata.var.highly_variable].copy()\n", "```" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AnnData object with n_obs × n_vars = 3372 × 1122\n", " obs: 'x', 'y', 'wf_index'\n", " uns: 
'log1p'\n", " obsm: 'X_spatial_2d'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# expression values have already been normalized (normalize_total) and log1p-transformed\n", "adata = sc.read_h5ad('/mnt/Data16Tc/home/haichao/code/SpaCon/ST_FC_cluster/mouse1/data/zxw1_wide_field/zxw1_cortical_map_half_brain_match_wf_conn.h5ad')  # replace with the path to your downloaded file\n", "adata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Build spatial graph**\n", "\n", "The `build_spatial_graph` function constructs a spatial graph from the spot coordinates of the spatial transcriptomics data. It can use three-dimensional coordinates from serial sections; this tutorial uses a single two-dimensional cortical map. The main parameters are:\n", "\n", "- `adata`: The spatial transcriptomics data. For a 3D graph, it must include the three-dimensional coordinates of each spot (i.e., the slice number and the two-dimensional coordinates within that slice).\n", "- `section_order`: A list giving the original order of the slices within the brain. The slices may be oriented differently, but their relative order must be strictly maintained.\n", "- `model`: The neighbor-search mode; this tutorial uses `'KNN'`.\n", "- `k_cutoff`: The number of nearest neighbors connected to each spot in `'KNN'` mode (15 here).\n", "- `rad_cutoff`: Neighborhood radius; each spot is connected to all other spots within this radius.\n", "- `rad_cutoff_Zaxis`: Inter-slice neighborhood radius; each spot is connected to spots in adjacent slices that lie within this radius.\n", "- `is_3d`: Whether to build a graph across multiple slices; set to `False` here because the data form a single 2D map.\n", "- `sec_x`: The column name in `adata.obs` that stores the x-coordinate of each spot within its slice.\n", "- `sec_y`: The column name in `adata.obs` that stores the y-coordinate of each spot within its slice.\n", "- `key_section`: The column name in `adata.obs` that stores the slice number (different numbers indicate different slices)."
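, "\n",
 "\n",
 "For intuition, a minimal NumPy sketch of a k-nearest-neighbor spatial graph follows. It is not SpaCon's implementation: `knn_edge_index` is a hypothetical helper, and adding one self-loop per spot is an assumption. That assumption is suggested, but not confirmed, by the `edge_index=[2, 53952]` reported by the next cell, since 53,952 = 3,372 spots × (15 neighbors + 1 self-loop).\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "\n",
 "def knn_edge_index(coords, k, self_loops=True):\n",
 "    # Hypothetical helper (not part of SpaCon): connect each spot to its k nearest\n",
 "    # neighbors by Euclidean distance and return a 2 x E array of (source, target) pairs.\n",
 "    coords = np.asarray(coords, dtype=float)\n",
 "    n = coords.shape[0]\n",
 "    # dense pairwise distances, shape (n, n); fine for a few thousand spots\n",
 "    diff = coords[:, None, :] - coords[None, :, :]\n",
 "    dist = np.sqrt((diff ** 2).sum(axis=-1))\n",
 "    np.fill_diagonal(dist, np.inf)  # a spot is never its own KNN neighbor\n",
 "    nbrs = np.argsort(dist, axis=1)[:, :k]  # indices of the k closest spots\n",
 "    src = np.repeat(np.arange(n), k)\n",
 "    dst = nbrs.ravel()\n",
 "    if self_loops:\n",
 "        # assumption: one self-loop per spot, as the tutorial's edge count suggests\n",
 "        src = np.concatenate([src, np.arange(n)])\n",
 "        dst = np.concatenate([dst, np.arange(n)])\n",
 "    return np.stack([src, dst])\n",
 "\n",
 "\n",
 "# toy 10 x 10 grid of spots\n",
 "xs, ys = np.meshgrid(np.arange(10), np.arange(10))\n",
 "coords = np.column_stack([xs.ravel(), ys.ravel()])\n",
 "edge_index = knn_edge_index(coords, k=15)\n",
 "print(edge_index.shape)  # (2, 1600): 100 spots x (15 neighbors + 1 self-loop)\n",
 "```\n",
 "\n",
 "On a regular grid many neighbors are equidistant, so which ones are kept depends on how ties are broken. For large datasets, a KD-tree (e.g. `scipy.spatial.cKDTree`) scales better than the dense distance matrix."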
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Data(x=[3372, 1122], edge_index=[2, 53952])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# build a KNN spatial graph (15 nearest neighbors per spot) from the 2D spot coordinates\n", "ST_graph_data, st_adj = build_spatial_graph(adata=adata, k_cutoff=15, model='KNN',\n", "                                            sec_x='y', sec_y='x', is_3d=False)\n", "ST_graph_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Load connectivity data and build connection graph**\n", "\n", "The `build_connection_graph` function constructs a connection graph from a spot-by-spot connectivity matrix. The main parameters are:\n", "\n", "- `nt_adj`: An n × n matrix of connection strengths between spots, where n is the number of spots in the spatial transcriptomics data. Here it is the wide-field functional connectivity (FC) matrix `wf_FC_mouse1`.\n", "- `threshold`: Filtering threshold; connection strengths below this value are set to zero." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "distance_weight = True\n", "decay_rate = 0.006\n", "neighbor_weight1 = False  # if True, spots closer than the neighbor_weight1_percentage-th distance percentile keep weight 1\n", "neighbor_weight1_percentage = 30\n", "\n", "# n x n wide-field FC matrix aligned to the spots in adata (replace with the path to your downloaded file)\n", "wf_FC_mouse1 = np.load('/mnt/Data16Tc/home/haichao/code/SpaCon/ST_FC_cluster/mouse1/data/zxw1_wide_field/wf_FC_mouse1_fliter_100um.npy')\n", "\n", "if distance_weight:\n", "    coor = np.array(adata.obs[['x', 'y']])\n", "    for i in range(wf_FC_mouse1.shape[0]):\n", "        # Euclidean distance from spot i to every spot\n", "        distances = np.linalg.norm(coor - coor[i], axis=1)\n", "        # note: 1 / exp(-r * d) equals exp(r * d), so this weight increases with distance\n", "        weight = 1 / np.exp(-decay_rate * distances)\n", "        if neighbor_weight1:\n", "            neighbor = np.percentile(distances, neighbor_weight1_percentage)\n", "            weight[distances < neighbor] = 1\n", "        wf_FC_mouse1[i] = np.multiply(wf_FC_mouse1[i], weight)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def 
filter_matrix(mat, thr, per):\n", "    # keep, in each row, at most the top `per` fraction of entries that exceed `thr`\n", "    n = mat.shape[0]\n", "    k_per_row = int(per * n)  # max entries kept per row (int(0.1 * 3372) = 337 here)\n", "    filtered_mat = np.zeros_like(mat)  # initialize the filtered matrix\n", "\n", "    for i in range(n):\n", "        row = mat[i, :].copy()  # copy the current row to avoid modifying the original matrix\n", "\n", "        # Step 1: keep entries greater than thr\n", "        mask = row > thr\n", "        valid_indices = np.where(mask)[0]\n", "\n", "        if len(valid_indices) == 0:\n", "            continue  # no entries above the threshold, skip this row\n", "\n", "        # Step 2: sort the remaining entries in descending order and select the top k\n", "        valid_values = row[valid_indices]\n", "        sorted_indices = np.argsort(-valid_values)  # indices for a descending sort\n", "\n", "        k = min(k_per_row, len(sorted_indices))\n", "        selected = sorted_indices[:k]\n", "        selected_indices = valid_indices[selected]\n", "\n", "        # update the filtered matrix\n", "        filtered_mat[i, selected_indices] = row[selected_indices]\n", "\n", "    # optional step: make the matrix symmetric\n", "    # filtered_mat = np.maximum(filtered_mat, filtered_mat.T)\n", "\n", "    return filtered_mat\n", "\n", "thr = 0.8\n", "max_retention_each_row = 0.1\n", "wf_FC_mouse1 = filter_matrix(wf_FC_mouse1, thr=thr, per=max_retention_each_row)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.08214357580183748\n" ] } ], "source": [ "wf_FC_mouse1[wf_FC_mouse1 < thr] = 0\n", "# fraction of nonzero entries remaining in the FC matrix\n", "count_after = np.count_nonzero(wf_FC_mouse1)\n", "proportion_after = count_after / (wf_FC_mouse1.shape[0] * wf_FC_mouse1.shape[1])\n", "print(proportion_after)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Data(x=[3372, 1122], edge_index=[2, 934013])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"NT_graph_data = build_connection_graph(adata, wf_FC_mouse1, threshold=thr)\n", "NT_graph_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Neighbor-based subgraph sampling**\n", "\n", "The `neighbor_sample` function performs subgraph sampling from the input spatial graph and connection graph. Its main parameters include:\n", "\n", "- `batch_size`: The batch size for model training.\n", "- `train_num_neighbors`: The number of neighbors to sample for each node in each iteration. This parameter is used by the data loader during the model training process.\n", "- `eval_num_neighbors`: The number of neighbors to sample for each node in each iteration. This parameter is used by the data loader during the model evaluation process. If an entry is set to -1, all neighbors will be included.(default:`[-1]`)\n", "\n", "The function returns three data loaders: `train_loader`, `evaluate_loader_con`, and `evaluate_loader_spa`. The `train_loader` is used during the **model training process**. Meanwhile, `evaluate_loader_con` and `evaluate_loader_spa` are used for **model evaluation** on the **connection graph** and **spatial graph**, respectively." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "train_loader, evaluate_loader_con, evaluate_loader_spa = neighbor_sample(NT_graph_data, ST_graph_data, batch_size=64, train_num_neighbors=[20, 10, 10], num_workers=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model training**" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch:1|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 26.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:2|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 32.05it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:3|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 29.43it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:4|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 30.42it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:5|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 28.31it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:6|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 29.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:7|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 30.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:8|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 30.99it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:9|10\n" ] }, { 
"name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 32.25it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch:10|10\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 53/53 [00:01<00:00, 31.54it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Training completed! The model parameters have been saved to ./results_widefield/2025_07_24_14_06_20/model_params.pth\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "\n", "# hyper-parameters\n", "num_epoch = 10\n", "lr = 0.0001\n", "weight_decay = 1e-4\n", "hidden_dims = [adata.X.shape[1]] + [256, 128, 32] \n", "# model\n", "# fusion_method indicates the feature fusion method of the middle layer, you can choose 'add' or 'concat'\n", "model = SpaCon(hidden_dims=hidden_dims, fusion_method='concat').to(device)\n", "# if model_save_path=None, the model will not be saved\n", "results_save_path = f\"./results_widefield/{str(datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))}/\"\n", "os.makedirs(results_save_path, exist_ok=True)\n", "\n", "model = model_train(num_epoch, lr, weight_decay, model, train_loader, st_adj, model_save_path=results_save_path, device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Model evaluation**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The features obtained after model dimensionality reduction, named `feature_spa` and `feature_con`, are stored in the returned `adata.obsm`. These features can be used for subsequent cluster analysis." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Evaluating: 100%|██████████| 10116/10116 [00:01<00:00, 6952.61it/s]\n", "Evaluating: 100%|██████████| 10116/10116 [00:01<00:00, 8215.90it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The results have been saved in adata.obsm\n", "AnnData object with n_obs × n_vars = 3372 × 1122\n", " obs: 'x', 'y', 'wf_index'\n", " uns: 'log1p'\n", " obsm: 'X_spatial_2d', 'feature_spa', 'feature_con'\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "adata = model_eval(model, adata, NT_graph_data, ST_graph_data, evaluate_loader_con, evaluate_loader_spa, st_adj, layer_eval=True, device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Clustering**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `clustering` function clusters spots with the Louvain algorithm. Its key parameters are:\n", "\n", "* `adata`: The AnnData object returned by `model_eval`, which contains the clustering features (`feature_spa`, `feature_con`).\n", "* `alpha`: Balances local spatial information against global connection information in the clustering results; choose a value that suits your downstream task:\n", "    * With `alpha = 1`, the clustering relies more on global information.\n", "    * With `alpha = 0`, the clustering focuses more on local information.\n", "* `adata_save_path`: The path where the results are saved.\n", "* `cluster_resolution`: The resolution parameter of the Louvain clustering; higher values generally yield more clusters.\n", "\n", "The returned `path` indicates where the clustering results are saved." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"  <thead>\n",
"    <tr><th></th><th>x</th><th>y</th><th>wf_index</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>41</td><td>20</td><td>22</td></tr>\n",
"    <tr><th>1</th><td>42</td><td>20</td><td>23</td></tr>\n",
"    <tr><th>2</th><td>43</td><td>20</td><td>24</td></tr>\n",
"    <tr><th>3</th><td>44</td><td>20</td><td>25</td></tr>\n",
"    <tr><th>4</th><td>45</td><td>20</td><td>26</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>3367</th><td>25</td><td>103</td><td>1025</td></tr>\n",
"    <tr><th>3368</th><td>26</td><td>103</td><td>1047</td></tr>\n",
"    <tr><th>3369</th><td>27</td><td>103</td><td>1071</td></tr>\n",
"    <tr><th>3370</th><td>28</td><td>103</td><td>1096</td></tr>\n",
"    <tr><th>3371</th><td>29</td><td>103</td><td>1122</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>3372 rows × 3 columns</p>\n", "