Path Progression on a Track Graph (Bandit Task Real Data)¶
This notebook demonstrates goal-to-goal trial segmentation and path
progression analysis on a real hippocampal recording from the same
bandit-task dataset used in 19_real_data_bandit_task.ipynb.
Dataset: J16 session, 2021-07-10 (triple-Y maze, 3 reward patches, 6 well endpoints, about 24 minutes of position tracking at 500 Hz).
What you'll learn
- How to define goal regions from a track graph's endpoint nodes
- How to detect well-entry events and pair them into goal-to-goal trials
- How to compute normalized path progress (0 -> 1) along the track graph
- How to visualize per-trial and aggregate progression
Pipeline
track graph + positions
|
v
Environment.from_graph (1D linearized, but still queryable in 2D)
|
v
env.regions.buffer(...) for each endpoint = well region
|
v
detect_region_crossings(direction="entry") for each well
|
v
pair consecutive visits -> trials
|
v
path_progress(metric="geodesic")
import importlib.util
import itertools
from collections import Counter
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Circle
from neurospatial import Environment
from neurospatial.behavior import path_progress
from neurospatial.behavior.segmentation import detect_region_crossings
def _find_project_root(start: Path) -> Path:
    """Find the project root containing the real-data helper."""
    for path in (start, *start.parents):
        if (path / "data" / "load_bandit_data.py").exists():
            return path
    raise FileNotFoundError(f"Could not find data/load_bandit_data.py from {start}")
try:
    _start_path = Path(__file__).resolve().parent
except NameError:
    _start_path = Path.cwd().resolve()
_base_path = _find_project_root(_start_path)
_loader_spec = importlib.util.spec_from_file_location(
    "load_bandit_data", _base_path / "data" / "load_bandit_data.py"
)
if _loader_spec is None or _loader_spec.loader is None:
    raise ImportError("Could not load bandit data helper")
_loader_module = importlib.util.module_from_spec(_loader_spec)
_loader_spec.loader.exec_module(_loader_module)
load_neural_recording_from_files = _loader_module.load_neural_recording_from_files
# Shared styling (Okabe-Ito palette, consistent figure / font sizes)
import sys
_here = (
    str(Path(__file__).resolve().parent) if "__file__" in globals() else str(Path.cwd())
)
if _here not in sys.path:
    sys.path.insert(0, _here)
from _style import apply_style
apply_style(figsize=(12, 8), font_size=11)
1. Load the bandit-task recording¶
We only need the position trace and the track graph for this notebook; spike data is ignored.
data = load_neural_recording_from_files(_base_path / "data", "j1620210710_02_r1")
position_info = data["position_info"]
track_graph = data["track_graph"]
linear_edge_order = data["linear_edge_order"]
linear_edge_spacing = data["linear_edge_spacing"]
times = position_info.index.values
positions = position_info[["head_position_x", "head_position_y"]].values
duration = times[-1] - times[0]
print(f"Recording: {len(times):,} samples, {duration / 60:.1f} min")
print(
f"Track graph: {track_graph.number_of_nodes()} nodes, "
f"{track_graph.number_of_edges()} edges"
)
Recording: 709,321 samples, 23.6 min
Track graph: 10 nodes, 9 edges
2. Build the environment from the track graph¶
Environment.from_graph creates a 1D linearized environment whose bins
live along the track. The env is still queryable with 2D points
(env.bin_at(positions_2d)) -- coordinates are snapped to the nearest
track segment.
env = Environment.from_graph(
graph=track_graph,
edge_order=linear_edge_order,
edge_spacing=linear_edge_spacing,
bin_size=2.0,
name="bandit_track",
)
env.units = "cm"
print(f"Environment: {env.n_bins} bins along the linearized track")
print(f"is_linearized_track: {env.is_linearized_track}")
Environment: 248 bins along the linearized track
is_linearized_track: True
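To make the "snapped to the nearest track segment" idea concrete, here is a minimal pure-Python sketch of projecting a 2D point onto a set of line segments. It is an illustration only: the helper names `project_onto_segment` and `snap_to_track` are hypothetical and this is not how neurospatial implements `env.bin_at`.

```python
import math


def project_onto_segment(p, a, b):
    """Return the closest point to p on segment a-b and its distance to p."""
    ax, ay = a
    bx, by = b
    px, py = p
    dx, dy = bx - ax, by - ay
    seg_len_sq = dx * dx + dy * dy
    if seg_len_sq == 0.0:
        t = 0.0  # degenerate segment: a and b coincide
    else:
        # Fractional position of the projection, clamped to the segment.
        t = ((px - ax) * dx + (py - ay) * dy) / seg_len_sq
        t = max(0.0, min(1.0, t))
    qx, qy = ax + t * dx, ay + t * dy
    return (qx, qy), math.hypot(px - qx, py - qy)


def snap_to_track(p, segments):
    """Snap p to the nearest segment; return (segment index, snapped point)."""
    _, idx = min(
        (project_onto_segment(p, a, b)[1], i) for i, (a, b) in enumerate(segments)
    )
    return idx, project_onto_segment(p, *segments[idx])[0]


# Toy L-shaped track made of two segments (coordinates in cm).
segments = [((0.0, 0.0), (10.0, 0.0)), ((10.0, 0.0), (10.0, 10.0))]
idx, snapped = snap_to_track((4.0, 1.5), segments)
print(idx, snapped)  # 0 (4.0, 0.0)
```

A head position 1.5 cm off the first segment is assigned to that segment at the foot of the perpendicular, which is the behaviour the 2D query relies on.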
3. Define the wells as goal regions¶
The maze is a triple-Y (3 arms x 2 wells per arm = 6 wells). The 6 wells
correspond to the endpoint nodes of the track graph (those with
degree 1). For each endpoint we create a small buffered polygon
region; the trial-segmentation step that feeds path_progress needs
polygons, since point regions have zero area and regions_to_mask would
return empty masks.
endpoint_nodes = sorted(n for n in track_graph.nodes if track_graph.degree(n) == 1)
well_radius_cm = 8.0
well_names: list[str] = []
for n in endpoint_nodes:
    name = f"well_{n}"
    env.regions.buffer(
        np.asarray(track_graph.nodes[n]["pos"], dtype=float),
        well_radius_cm,
        name,
    )
    well_names.append(name)
print(f"Defined {len(well_names)} wells (radius={well_radius_cm} cm):")
for w in well_names:
    bins = env.bins_in_region(w)
    print(f" {w}: {len(bins)} track bins")
Defined 6 wells (radius=8.0 cm):
 well_0: 4 track bins
 well_1: 4 track bins
 well_2: 4 track bins
 well_3: 4 track bins
 well_4: 4 track bins
 well_5: 4 track bins
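The two ideas in this section (endpoints are the degree-1 nodes, and a well is everything within a fixed radius of an endpoint) can be sketched without any library. The toy Y-shaped graph and the `in_well` helper below are assumptions made for illustration; the real regions are buffered polygons, which only approximate a circle.

```python
import math
from collections import Counter

# Toy Y-shaped track: node 0 is the junction, nodes 1-3 are arm tips (cm).
node_pos = {0: (0.0, 0.0), 1: (0.0, 30.0), 2: (-26.0, -15.0), 3: (26.0, -15.0)}
edges = [(0, 1), (0, 2), (0, 3)]

# Count how many edges touch each node; endpoints have degree 1.
degree = Counter()
for u, v in edges:
    degree[u] += 1
    degree[v] += 1
endpoints = sorted(n for n in node_pos if degree[n] == 1)


def in_well(point, center, radius):
    """True if point lies within radius of center (circular well)."""
    return math.dist(point, center) <= radius


radius_cm = 8.0
samples = [(0.0, 25.0), (0.0, 10.0), (25.0, -14.0)]
labels = []
for p in samples:
    hit = [n for n in endpoints if in_well(p, node_pos[n], radius_cm)]
    labels.append(hit[0] if hit else None)
print(endpoints, labels)  # [1, 2, 3] [1, None, 3]
```

The middle sample sits on an arm but 20 cm from its tip, so it belongs to no well.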
Plot the maze with wells highlighted to confirm the regions land on the track endpoints.
fig, ax = plt.subplots(figsize=(9, 9), constrained_layout=True)
# Track edges
for u, v in track_graph.edges():
    pu = np.array(track_graph.nodes[u]["pos"])
    pv = np.array(track_graph.nodes[v]["pos"])
    ax.plot([pu[0], pv[0]], [pu[1], pv[1]], "k-", lw=4, alpha=0.3, zorder=1)
# Trajectory (subsampled)
ax.plot(
positions[::40, 0],
positions[::40, 1],
color="0.6",
lw=0.3,
alpha=0.4,
zorder=2,
label="trajectory",
)
# Well regions (drawn as circles at the endpoint node positions, matching
# the buffer radius used to build the region polygons).
for n, w in zip(endpoint_nodes, well_names, strict=True):
    cx, cy = track_graph.nodes[n]["pos"]
    ax.add_patch(
        Circle(
            (cx, cy),
            well_radius_cm,
            facecolor="tab:orange",
            edgecolor="black",
            lw=1.5,
            alpha=0.6,
            zorder=5,
        )
    )
    ax.annotate(
        w.replace("well_", "W"),
        (cx, cy),
        fontsize=10,
        fontweight="bold",
        ha="center",
        va="center",
        zorder=6,
    )
ax.set_aspect("equal")
ax.set_xlabel("x (cm)")
ax.set_ylabel("y (cm)")
ax.set_title("Triple-Y bandit maze: trajectory and 6 well regions")
ax.legend(loc="upper right")
plt.show()
4. Detect well-entry events¶
detect_region_crossings(..., direction="entry") returns the first
sample inside each region per visit. We collect entries from every well
and merge them chronologically into one visit list.
position_bins = env.bin_at(positions)
visits: list[tuple[float, str, int]] = []
for w in well_names:
    for cr in detect_region_crossings(position_bins, times, w, env, direction="entry"):
        visits.append((cr.time, w, int(cr.bin_index)))
visits.sort(key=lambda v: v[0])
print(f"Total well-entry events: {len(visits)}")
print("Visits per well:")
for w, n in Counter(v[1] for v in visits).most_common():
    print(f" {w}: {n}")
Total well bouts: 191
Bouts per well:
 well_4: 70
 well_5: 66
 well_1: 18
 well_0: 15
 well_2: 13
 well_3: 9
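Conceptually, an entry event is the first sample of each run during which the animal is inside a region. A minimal sketch of that rising-edge logic follows; `entry_indices` is a hypothetical helper, not the library function, and it assumes a recording that starts inside a region counts as an entry at index 0, which `detect_region_crossings` may or may not do.

```python
def entry_indices(inside):
    """Indices where `inside` switches from False to True (start of each visit)."""
    entries = []
    prev = False
    for i, now in enumerate(inside):
        if now and not prev:
            entries.append(i)
        prev = now
    return entries


# Toy membership trace for one well, sampled every 0.1 s.
inside = [False, True, True, False, False, True, False, True, True]
sample_times = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
entry_times = [sample_times[i] for i in entry_indices(inside)]
print(entry_times)  # [0.1, 0.5, 0.7]
```

Note that a one-sample exit (index 6) is enough to create a new entry, which is why the next section has to handle consecutive entries to the same well.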
5. Pair consecutive visits into goal-to-goal trials¶
Trial i = from visit i to visit i+1:
- t0 = visit_i.time (animal enters well A)
- t1 = visit_{i+1}.time (animal enters well B)
- start_bin = A's entry bin, goal_bin = B's entry bin
Self-transitions (consecutive entries to the same well, usually a brief excursion and re-entry) are dropped.
Interpretation note. With this convention each trial spans the full
epoch from "animal arrived at A" to "animal arrived at B". For long
trials the animal often sits at A for several seconds (consuming
reward, sniffing the patch) before committing to the travel; during
that period path_progress correctly reports a value near 0, because
the animal hasn't actually moved toward B yet. The per-trial curves in
section 8 will show this as a flat segment followed by a ramp -- that's
real behaviour, not a bug.
trials = []
for (t0, a, ba), (t1, b, bb) in itertools.pairwise(visits):
    if a == b:
        continue
    trials.append(
        {"t0": t0, "t1": t1, "start": a, "goal": b, "start_bin": ba, "goal_bin": bb}
    )
durations = np.array([tr["t1"] - tr["t0"] for tr in trials])
print(f"Trials: {len(trials)} (after dropping self-transitions)")
print(
    f"Duration (s): min={durations.min():.2f} "
    f"median={np.median(durations):.2f} max={durations.max():.2f}"
)
# Transition matrix (start well -> goal well counts)
labels = well_names
idx = {w: i for i, w in enumerate(labels)}
M = np.zeros((len(labels), len(labels)), dtype=int)
for tr in trials:
    M[idx[tr["start"]], idx[tr["goal"]]] += 1
fig, ax = plt.subplots(figsize=(7, 6), constrained_layout=True)
im = ax.imshow(M, cmap="viridis")
ax.set_xticks(range(len(labels)))
ax.set_yticks(range(len(labels)))
ax.set_xticklabels([w.replace("well_", "W") for w in labels])
ax.set_yticklabels([w.replace("well_", "W") for w in labels])
ax.set_xlabel("goal well")
ax.set_ylabel("start well")
ax.set_title("Transition counts (start -> goal)")
for i in range(len(labels)):
    for j in range(len(labels)):
        if M[i, j]:
            ax.text(
                j,
                i,
                M[i, j],
                ha="center",
                va="center",
                color="white" if M[i, j] < M.max() / 2 else "black",
                fontsize=10,
            )
plt.colorbar(im, ax=ax, label="n trials")
plt.show()
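The pairing rule in this section can be checked on a toy visit list. The sketch below uses made-up times and well names, and `zip(visits, visits[1:])` in place of `itertools.pairwise` so that it also runs on older Python versions.

```python
from collections import Counter

# Toy chronological visit list: (entry time in s, well name).
visits = [
    (0.0, "well_0"),
    (4.0, "well_1"),
    (5.0, "well_1"),  # brief exit and re-entry at the same well
    (9.0, "well_0"),
    (15.0, "well_4"),
]

# Pair each visit with the next one; skip pairs that stay at the same well.
trials = [
    {"t0": t0, "t1": t1, "start": a, "goal": b}
    for (t0, a), (t1, b) in zip(visits, visits[1:])
    if a != b
]
transitions = Counter((tr["start"], tr["goal"]) for tr in trials)

print(len(trials))  # 3
print(trials[1])    # {'t0': 5.0, 't1': 9.0, 'start': 'well_1', 'goal': 'well_0'}
```

One consequence worth noticing: after a dropped self-transition, the next trial starts at the re-entry time (5.0 s here), not at the first arrival (4.0 s).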
6. Compute path progression¶
Build per-timepoint start_bins / goal_bins arrays from the trial
list, then call path_progress. A half-open mask [t0, t1) keeps
adjacent trials from overwriting each other's boundary timestamp (the
entry into well B is the end of trial A->B and the start of trial
B->C, but each timepoint can belong to only one trial).
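A toy timeline shows what the half-open rule does at a shared boundary. The values below are invented; only the `t0 <= t < t1` comparison mirrors the notebook.

```python
# Seven samples and two back-to-back trials that share the timestamp 2.0.
sample_times = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
toy_trials = [
    {"t0": 0.5, "t1": 2.0, "goal_bin": 10},
    {"t0": 2.0, "t1": 3.0, "goal_bin": 20},
]

goal_bins = [-1] * len(sample_times)  # -1 marks "outside any trial"
for tr in toy_trials:
    for i, t in enumerate(sample_times):
        if tr["t0"] <= t < tr["t1"]:  # half-open: includes t0, excludes t1
            goal_bins[i] = tr["goal_bin"]

print(goal_bins)  # [-1, 10, 10, 10, 20, 20, -1]
```

The sample at 2.0 s is claimed only by the second trial, so the result does not depend on the order in which trials are written into the array.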
metric="geodesic" is essential on a track graph: it measures distance
along the connectivity graph, so going from one arm to another correctly
routes through the centre node instead of cutting across the linearized
gaps between arms.
start_bins = np.full(len(times), -1, dtype=np.int_)
goal_bins = np.full(len(times), -1, dtype=np.int_)
for tr in trials:
    mask = (times >= tr["t0"]) & (times < tr["t1"])
    start_bins[mask] = tr["start_bin"]
    goal_bins[mask] = tr["goal_bin"]
progress = path_progress(
position_bins,
env,
start_bins=start_bins,
goal_bins=goal_bins,
metric="geodesic",
)
print(f"progress range: [{np.nanmin(progress):.3f}, {np.nanmax(progress):.3f}]")
print(f"timepoints inside a trial: {(~np.isnan(progress)).sum():,} / {len(progress):,}")
progress range: [0.009, 0.991]
timepoints inside a trial: 217,735 / 709,321
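To see why the geodesic metric matters, consider a toy Y-shaped track whose bins are numbered in linearized order. The sketch below uses breadth-first search for hop distances and assumes progress is defined as distance travelled from the start divided by the start-to-goal distance; that definition is an assumption for illustration and may differ from the exact formula inside `path_progress`.

```python
from collections import deque

# Bins 0-2: arm A, bin 3: junction, bins 4-6: arm B, bins 7-9: arm C.
# Bins 6 and 7 are neighbours in linearized order but not on the graph.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (3, 7), (7, 8), (8, 9)]
adj = {}
for u, v in edges:
    adj.setdefault(u, []).append(v)
    adj.setdefault(v, []).append(u)


def hops_from(source):
    """Breadth-first search: number of edges from source to every bin."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


start, goal = 6, 9  # tip of arm B to tip of arm C
from_start = hops_from(start)
total = from_start[goal]
visited = [6, 5, 4, 3, 7, 8, 9]  # the route through the junction
progress = [from_start[b] / total for b in visited]

print(total, abs(goal - start))  # 6 3
print([round(p, 2) for p in progress])  # [0.0, 0.17, 0.33, 0.5, 0.67, 0.83, 1.0]
```

The linearized index difference (3) underestimates the true route length (6 hops), because it jumps across the gap between arms B and C instead of passing through the junction.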
7. Linear position and path progression over the session¶
Stacked time-series view: linear position along the track on top, path progression on bottom (shared x-axis). Each trial appears as a 0 -> 1 ramp in the bottom row; the corresponding linear-position trace shows the actual arm the animal traversed during that ramp. Gaps in the progression trace are timepoints outside any trial (NaN).
linear_pos = position_info["linear_position"].values
t_rel = times - times[0]
# Subsample for plotting speed without dropping trial structure (~30 Hz).
SUB = 17
fig, axes = plt.subplots(2, 1, figsize=(14, 7), sharex=True, constrained_layout=True)
ax = axes[0]
ax.plot(t_rel[::SUB], linear_pos[::SUB], color="steelblue", lw=0.4)
# Mark each well's linear coordinate so it's easy to see which arm the
# animal is on at any time.
for w in well_names:
    well_bins = env.bins_in_region(w)
    if len(well_bins) == 0:
        continue
    # Bin centres for graph envs are 1-D linear coordinates.
    y = float(env.bin_centers[well_bins[0], 0])
    ax.axhline(y, color="tab:orange", lw=0.6, alpha=0.5)
    ax.text(
        t_rel[-1] * 1.005,
        y,
        w.replace("well_", "W"),
        fontsize=8,
        va="center",
        color="tab:orange",
    )
ax.set_ylabel("linear position (cm)")
ax.set_title("Linear position over the session")
ax = axes[1]
ax.plot(t_rel[::SUB], progress[::SUB], color="tab:blue", lw=0.4)
ax.axhline(1.0, color="tab:gray", lw=0.5, ls="--")
ax.set_xlabel("time (s, relative to session start)")
ax.set_ylabel("path progress")
ax.set_ylim(-0.05, 1.1)
ax.set_title("Path progression (NaN outside trials)")
plt.show()
Same two rows zoomed to a ~3-minute window so individual trials are resolvable -- each ramp on the bottom row should align with a sweep across linear-position values on the top row.
zoom_start = float(trials[0]["t0"] - times[0])
zoom_end = zoom_start + 180.0 # 3-minute window
zoom_mask = (t_rel >= zoom_start) & (t_rel <= zoom_end)
fig, axes = plt.subplots(2, 1, figsize=(14, 6), sharex=True, constrained_layout=True)
ax = axes[0]
ax.plot(t_rel[zoom_mask], linear_pos[zoom_mask], color="steelblue", lw=0.8)
for w in well_names:
    well_bins = env.bins_in_region(w)
    if len(well_bins) == 0:
        continue
    y = float(env.bin_centers[well_bins[0], 0])
    ax.axhline(y, color="tab:orange", lw=0.6, alpha=0.5)
ax.set_ylabel("linear position (cm)")
ax.set_title(f"Zoomed window: {zoom_start:.0f}-{zoom_end:.0f} s")
ax = axes[1]
ax.plot(t_rel[zoom_mask], progress[zoom_mask], color="tab:blue", lw=0.8)
# Shade each trial in the window for visual reference.
for tr in trials:
    t0r = tr["t0"] - times[0]
    t1r = tr["t1"] - times[0]
    if t1r < zoom_start or t0r > zoom_end:
        continue
    ax.axvspan(t0r, t1r, color="tab:blue", alpha=0.07)
ax.axhline(1.0, color="tab:gray", lw=0.5, ls="--")
ax.set_xlabel("time (s, relative to session start)")
ax.set_ylabel("path progress")
ax.set_ylim(-0.05, 1.1)
plt.show()
8. Per-trial progress curves¶
Plot the first nine trials' progress over time. A clean goal-directed run shows a near-monotonic rise from 0 to ~1. Plateaus near 0 indicate the animal lingering at the start well; non-monotonic dips mean the animal looped back partway.
N_SHOW = 9
fig, axes = plt.subplots(3, 3, figsize=(13, 9), sharey=True, constrained_layout=True)
for ax, tr in zip(axes.flat, trials[:N_SHOW], strict=False):
    mask = (times >= tr["t0"]) & (times < tr["t1"])
    t = times[mask] - tr["t0"]
    p = progress[mask]
    ax.plot(t, p, color="tab:blue", lw=1.5)
    ax.axhline(1.0, color="tab:gray", lw=0.5, ls="--")
    ax.set_ylim(-0.05, 1.1)
    ax.set_title(f"{tr['start']} -> {tr['goal']} ({t[-1]:.1f}s)", fontsize=10)
    ax.set_xlabel("t since start (s)")
    ax.set_ylabel("progress")
fig.suptitle("Per-trial path progress (first 9 trials)", fontsize=13, fontweight="bold")
plt.show()
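The "looped back partway" reading of a non-monotonic curve can be turned into a simple number. The sketch below is one possible summary, not part of neurospatial: `backtrack_fraction` is a hypothetical helper, and it assumes NaN samples have already been removed from the trial's progress trace.

```python
def backtrack_fraction(progress, tol=1e-9):
    """Fraction of consecutive steps on which progress drops by more than tol."""
    steps = [b - a for a, b in zip(progress, progress[1:])]
    if not steps:
        return 0.0
    return sum(1 for s in steps if s < -tol) / len(steps)


clean = [0.0, 0.0, 0.1, 0.4, 0.8, 1.0]        # plateau at the well, then a ramp
looped = [0.0, 0.3, 0.5, 0.3, 0.2, 0.6, 1.0]  # turns back mid-run, then completes

print(backtrack_fraction(clean))   # 0.0
print(backtrack_fraction(looped))  # 2 of 6 steps decrease
```

A plateau near 0 does not count as backtracking under this measure, which matches the distinction drawn above between lingering at the start well and looping back.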
9. Aggregate progress vs. normalized trial time¶
Resample each trial onto a common time axis [0, 1] and average across
trials. The expected shape is a sigmoid-ish ramp -- animals tend to
spend a fraction of trial time near the start (sniffing, consuming
reward) before committing to a directed run.
N_GRID = 100
t_norm = np.linspace(0.0, 1.0, N_GRID)
per_trial = np.full((len(trials), N_GRID), np.nan)
for i, tr in enumerate(trials):
    mask = (times >= tr["t0"]) & (times < tr["t1"])
    if mask.sum() < 2:
        continue
    t_local = (times[mask] - tr["t0"]) / max(tr["t1"] - tr["t0"], 1e-9)
    p = progress[mask]
    valid = ~np.isnan(p)
    if valid.sum() < 2:
        continue
    per_trial[i] = np.interp(t_norm, t_local[valid], p[valid])
mean_p = np.nanmean(per_trial, axis=0)
lo_p = np.nanpercentile(per_trial, 25, axis=0)
hi_p = np.nanpercentile(per_trial, 75, axis=0)
fig, axes = plt.subplots(1, 2, figsize=(13, 5), constrained_layout=True)
ax = axes[0]
ax.hist(durations, bins=30, color="steelblue", edgecolor="black", alpha=0.8)
ax.axvline(
np.median(durations),
color="red",
lw=2,
ls="--",
label=f"median = {np.median(durations):.1f}s",
)
ax.set_xlabel("trial duration (s)")
ax.set_ylabel("n trials")
ax.set_title(f"Trial duration distribution ({len(trials)} trials)")
ax.legend()
ax = axes[1]
ax.fill_between(t_norm, lo_p, hi_p, color="tab:blue", alpha=0.2, label="IQR (25-75%)")
ax.plot(t_norm, mean_p, color="tab:blue", lw=2.5, label="mean")
ax.plot([0, 1], [0, 1], color="tab:gray", lw=1, ls=":", label="ideal linear ramp")
ax.set_xlabel("normalized trial time")
ax.set_ylabel("path progress")
ax.set_title("Average progress across all trials")
ax.set_xlim(0, 1)
ax.set_ylim(-0.05, 1.1)
ax.legend(loc="lower right")
plt.show()
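The resampling step can be illustrated with two toy trials of different lengths. This sketch uses a small hand-written linear interpolator instead of `np.interp`, and it normalizes by the last sample time rather than by the next well entry, so it is a simplified stand-in for the cell above rather than a copy of it.

```python
def interp(x, xp, fp):
    """Piecewise-linear interpolation of (xp, fp) at x; xp must be increasing."""
    if x <= xp[0]:
        return fp[0]
    if x >= xp[-1]:
        return fp[-1]
    for x0, x1, y0, y1 in zip(xp, xp[1:], fp, fp[1:]):
        if x0 <= x <= x1:
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


def resample(sample_times, values, n_grid):
    """Map a trial onto n_grid evenly spaced points of normalized time [0, 1]."""
    t0, t1 = sample_times[0], sample_times[-1]
    t_norm = [(t - t0) / (t1 - t0) for t in sample_times]
    grid = [i / (n_grid - 1) for i in range(n_grid)]
    return [interp(g, t_norm, values) for g in grid]


# A short direct run and a longer trial that lingers at the start well.
short = resample([0.0, 1.0, 2.0], [0.0, 0.5, 1.0], 5)
lingering = resample([0.0, 2.0, 4.0, 6.0, 8.0], [0.0, 0.0, 0.0, 0.5, 1.0], 5)
mean = [(a + b) / 2 for a, b in zip(short, lingering)]

print(short)      # [0.0, 0.25, 0.5, 0.75, 1.0]
print(lingering)  # [0.0, 0.0, 0.0, 0.5, 1.0]
print(mean)       # [0.0, 0.125, 0.25, 0.625, 1.0]
```

Even with one perfectly linear trial in the mix, the average falls below the diagonal, which is the shape the aggregate plot is expected to show when animals pause before leaving the start well.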
10. Spatial view: one trial colored by progress¶
Plot one trial's 2D trajectory with each sample coloured by its progress value. This makes the geodesic distance interpretation concrete -- progress along the path, not Euclidean distance to goal. The trajectory begins inside the start well (progress ~ 0) and ends inside the goal well (progress ~ 1).
# Pick a representative trial: the first one lasting between 1x and 3x the median duration.
median_dur = float(np.median(durations))
clean_trials = [
tr for tr in trials if median_dur < tr["t1"] - tr["t0"] < 3 * median_dur
]
demo_trial = clean_trials[0] if clean_trials else trials[0]
mask = (times >= demo_trial["t0"]) & (times < demo_trial["t1"])
xy = positions[mask]
p = progress[mask]
fig, ax = plt.subplots(figsize=(9, 9), constrained_layout=True)
# Track edges
for u, v in track_graph.edges():
    pu = np.array(track_graph.nodes[u]["pos"])
    pv = np.array(track_graph.nodes[v]["pos"])
    ax.plot([pu[0], pv[0]], [pu[1], pv[1]], "k-", lw=4, alpha=0.3, zorder=1)
# Wells (faded) - drawn as circles at endpoint node positions
for n in endpoint_nodes:
    cx, cy = track_graph.nodes[n]["pos"]
    ax.add_patch(
        Circle(
            (cx, cy),
            well_radius_cm,
            facecolor="tab:orange",
            edgecolor="black",
            lw=1.0,
            alpha=0.25,
            zorder=2,
        )
    )
# Trajectory coloured by progress
sc = ax.scatter(
xy[:, 0], xy[:, 1], c=p, cmap="viridis", vmin=0.0, vmax=1.0, s=8, zorder=5
)
ax.scatter(
*xy[0],
color="white",
edgecolor="black",
s=140,
marker="o",
linewidths=1.5,
zorder=10,
label=f"start ({demo_trial['start']})",
)
ax.scatter(
*xy[-1],
color="black",
edgecolor="white",
s=140,
marker="s",
linewidths=1.5,
zorder=10,
label=f"goal ({demo_trial['goal']})",
)
ax.set_aspect("equal")
ax.set_xlabel("x (cm)")
ax.set_ylabel("y (cm)")
ax.set_title(
f"Trial {demo_trial['start']} -> {demo_trial['goal']} "
f"({demo_trial['t1'] - demo_trial['t0']:.1f}s)"
)
plt.colorbar(sc, ax=ax, label="path progress")
ax.legend(loc="upper right")
plt.show()
Summary¶
Recipe for goal-to-goal path progression on a track graph:
- Environment from track_graph via Environment.from_graph(...).
- Goal regions from endpoint nodes via env.regions.buffer(point, radius, name) (must be polygons, not points).
- Detect entries to each well with detect_region_crossings(..., direction="entry") and merge events chronologically.
- Pair consecutive visits into trials; drop self-transitions.
- Build per-timepoint arrays start_bins[t], goal_bins[t] using a half-open mask [t0, t1) so adjacent trials don't collide at the boundary timestamp.
- Compute path_progress(position_bins, env, start_bins=..., goal_bins=..., metric="geodesic"). Use geodesic on track graphs so distances respect the topology rather than cutting across the linearized 1D gaps.
path_progress reads start/goal per timepoint, so multi-well
shuttling, free choice, or task-cued goals are all just different ways
of populating the two arrays.
See also¶
- 19_real_data_bandit_task.ipynb - place-field analysis on this same dataset
- 14_behavioral_segmentation.ipynb - alternative trial segmentation tools
- path_progress docstring (src/neurospatial/behavior/navigation.py)