Refactor PathStep to represent edges with source/destination and before/after state

- Add callback and interval parameters to find_stochastic_path() for adaptive weights
- Add get_influence() method to compute weighted score contribution per factor
- Rename graph_node/output_chord to source_node/destination_node/source_chord/destination_chord
- Rename voice_stay_count to sustain_count_before/after
- Rename node_visit_counts to last_visited_count_before/after
- Remove redundant internal state from Path - derive from steps
- Each PathStep now fully self-contained with before/after state
This commit is contained in:
Michael Winter 2026-03-16 14:00:10 +01:00
parent c682a1df02
commit 1926930c3d
2 changed files with 129 additions and 48 deletions

View file

@ -7,7 +7,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
import networkx as nx import networkx as nx
from random import choices, seed from random import choices, seed
from typing import Iterator from typing import Callable, Iterator
from .path import Path from .path import Path
@ -34,6 +34,8 @@ class PathFinder:
start_chord: "Chord | None" = None, start_chord: "Chord | None" = None,
max_length: int = 100, max_length: int = 100,
weights_config: dict | None = None, weights_config: dict | None = None,
callback: Callable[[int, Path, dict], None] | None = None,
interval: int = 1,
) -> Path: ) -> Path:
"""Find a stochastic path through the graph. """Find a stochastic path through the graph.
@ -55,6 +57,7 @@ class PathFinder:
) )
graph_node = original_chord graph_node = original_chord
step_num = 0
for _ in range(max_length): for _ in range(max_length):
out_edges = list(self.graph.out_edges(graph_node, data=True)) out_edges = list(self.graph.out_edges(graph_node, data=True))
@ -62,15 +65,25 @@ class PathFinder:
if not out_edges: if not out_edges:
break break
# Derive state from last step (or initialize fresh for step 0)
if path_obj.steps:
last_step = path_obj.steps[-1]
voice_stay_count = last_step.sustain_count_after
node_visit_counts = last_step.last_visited_count_after
else:
# First step - derive from path object's current state
voice_stay_count = tuple(0 for _ in range(len(path_obj._voice_map)))
node_visit_counts = {node: 0 for node in set(self.graph.nodes())}
# Build candidates with raw scores # Build candidates with raw scores
candidates = self._build_candidates( candidates = self._build_candidates(
out_edges, out_edges,
path_obj.output_chords, path_obj.output_chords,
weights_config, weights_config,
tuple(path_obj._voice_stay_count), voice_stay_count,
path_obj.graph_chords, path_obj.graph_chords,
path_obj._cumulative_trans, path_obj._cumulative_trans,
path_obj._node_visit_counts, node_visit_counts,
) )
# Compute weights from raw scores # Compute weights from raw scores
@ -88,13 +101,17 @@ class PathFinder:
# Use path.step() to handle all voice-leading and state updates # Use path.step() to handle all voice-leading and state updates
path_obj.step( path_obj.step(
graph_node=chosen.graph_node, edge=chosen.edge,
edge_data=chosen.edge[2],
candidates=candidates, candidates=candidates,
chosen_scores=chosen.scores, chosen_scores=chosen.scores,
) )
graph_node = chosen.graph_node graph_node = chosen.graph_node
step_num += 1
# Invoke callback if configured
if callback is not None and step_num % interval == 0:
callback(step_num, path_obj, weights_config)
return path_obj return path_obj

View file

@ -16,16 +16,20 @@ if TYPE_CHECKING:
@dataclass @dataclass
class PathStep: class PathStep:
"""Stores data for a single step in the path.""" """Stores data for a single step (edge) in the path."""
graph_node: Chord source_node: Chord
output_chord: Chord destination_node: Chord
source_chord: Chord
destination_chord: Chord
transposition: Pitch | None = None transposition: Pitch | None = None
movements: dict[int, int] = field(default_factory=dict) movements: dict[int, int] = field(default_factory=dict)
scores: dict[str, float] = field(default_factory=dict) scores: dict[str, float] = field(default_factory=dict)
candidates: list["Candidate"] = field(default_factory=list) candidates: list["Candidate"] = field(default_factory=list)
node_visit_counts: dict | None = None last_visited_count_before: dict | None = None
voice_stay_count: tuple[int, ...] | None = None last_visited_count_after: dict | None = None
sustain_count_before: tuple[int, ...] | None = None
sustain_count_after: tuple[int, ...] | None = None
class Path: class Path:
@ -38,36 +42,57 @@ class Path:
self.steps: list[PathStep] = [] self.steps: list[PathStep] = []
self.weights_config = weights_config if weights_config is not None else {} self.weights_config = weights_config if weights_config is not None else {}
# State for tracking # State needed for step computation
self._node_visit_counts: dict = {}
self._voice_stay_count: list[int] = []
self._voice_map: list[int] = [] # which voice is at each position self._voice_map: list[int] = [] # which voice is at each position
self._cumulative_trans: Pitch | None = None # cumulative transposition self._cumulative_trans: Pitch | None = None # cumulative transposition
self._graph_nodes: set = set() # all graph nodes for visit tracking
self._num_voices: int = 0 # number of voices
def init_state( def init_state(
self, graph_nodes: set, num_voices: int, initial_chord: Chord self, graph_nodes: set, num_voices: int, initial_chord: Chord
) -> None: ) -> None:
"""Initialize state after graph is known.""" """Initialize state after graph is known."""
self._node_visit_counts = {node: 0 for node in graph_nodes} self._graph_nodes = graph_nodes
self._node_visit_counts[initial_chord] = 0 self._num_voices = num_voices
self._voice_stay_count = [0] * num_voices
self._voice_map = list(range(num_voices)) # voice i at position i self._voice_map = list(range(num_voices)) # voice i at position i
dims = initial_chord.dims dims = initial_chord.dims
self._cumulative_trans = Pitch(tuple(0 for _ in range(len(dims))), dims) self._cumulative_trans = Pitch(tuple(0 for _ in range(len(dims))), dims)
def _get_last_visited_counts(self) -> dict:
"""Get last visited counts from the last step, or initialize fresh."""
if self.steps:
last_step = self.steps[-1]
if last_step.last_visited_count_after is not None:
return dict(last_step.last_visited_count_after)
# Initialize fresh: all nodes start at 0 (except initial which we set to 0 explicitly)
return {node: 0 for node in self._graph_nodes}
def _get_sustain_counts(self) -> tuple:
"""Get sustain counts from the last step, or initialize fresh."""
if self.steps:
last_step = self.steps[-1]
if last_step.sustain_count_after is not None:
return last_step.sustain_count_after
# Initialize fresh: all voices start at 0
return tuple(0 for _ in range(self._num_voices))
def step( def step(
self, self,
graph_node: "Chord", edge: tuple,
edge_data: dict,
candidates: list["Candidate"], candidates: list["Candidate"],
chosen_scores: dict[str, float] | None = None, chosen_scores: dict[str, float] | None = None,
) -> PathStep: ) -> PathStep:
"""Process a step: update state, compute output, return step. """Process a step: update state, compute output, return step.
Takes graph_node and edge_data, handles all voice-leading internally. Takes edge (source_node, destination_node, edge_data), handles all voice-leading internally.
""" """
# Get edge information source_node = edge[0]
destination_node = edge[1]
edge_data = edge[2]
trans = edge_data.get("transposition") trans = edge_data.get("transposition")
movement = edge_data.get("movements", {}) movement = edge_data.get("movements", {})
@ -75,8 +100,8 @@ class Path:
if trans is not None: if trans is not None:
self._cumulative_trans = self._cumulative_trans.transpose(trans) self._cumulative_trans = self._cumulative_trans.transpose(trans)
# Transpose the graph node # Transpose the destination node
transposed = graph_node.transpose(self._cumulative_trans) transposed = destination_node.transpose(self._cumulative_trans)
# Update voice map based on movement # Update voice map based on movement
new_voice_map = [None] * len(self._voice_map) new_voice_map = [None] * len(self._voice_map)
@ -88,51 +113,58 @@ class Path:
reordered_pitches = tuple( reordered_pitches = tuple(
transposed.pitches[self._voice_map[i]] for i in range(len(self._voice_map)) transposed.pitches[self._voice_map[i]] for i in range(len(self._voice_map))
) )
output_chord = Chord(reordered_pitches, graph_node.dims) destination_chord = Chord(reordered_pitches, destination_node.dims)
# Get previous output chord # Get previous output chord
prev_output_chord = self.output_chords[-1] source_chord = self.output_chords[-1]
# Increment all node visit counts # Get BEFORE state from last step (or initialize fresh)
for node in self._node_visit_counts: last_visited_before = self._get_last_visited_counts()
self._node_visit_counts[node] += 1 sustain_before = self._get_sustain_counts()
# Update voice stay counts (matching master: compare position i with position i) # Compute AFTER state
for voice_idx in range(len(self._voice_stay_count)): last_visited_after = dict(last_visited_before)
curr_cents = prev_output_chord.pitches[voice_idx].to_cents() for node in last_visited_after:
next_cents = output_chord.pitches[voice_idx].to_cents() last_visited_after[node] += 1
last_visited_after[destination_node] = 0
sustain_after = list(sustain_before)
for voice_idx in range(len(sustain_after)):
curr_cents = source_chord.pitches[voice_idx].to_cents()
next_cents = destination_chord.pitches[voice_idx].to_cents()
if curr_cents == next_cents: if curr_cents == next_cents:
self._voice_stay_count[voice_idx] += 1 sustain_after[voice_idx] += 1
else: else:
self._voice_stay_count[voice_idx] = 0 sustain_after[voice_idx] = 0
# Create step with current state # Create step with before and after state
step = PathStep( step = PathStep(
graph_node=graph_node, source_node=source_node,
output_chord=output_chord, destination_node=destination_node,
source_chord=source_chord,
destination_chord=destination_chord,
transposition=trans, transposition=trans,
movements=movement, movements=movement,
scores=chosen_scores if chosen_scores is not None else {}, scores=chosen_scores if chosen_scores is not None else {},
candidates=candidates, candidates=candidates,
node_visit_counts=dict(self._node_visit_counts), last_visited_count_before=last_visited_before,
voice_stay_count=tuple(self._voice_stay_count), last_visited_count_after=last_visited_after,
sustain_count_before=sustain_before,
sustain_count_after=tuple(sustain_after),
) )
# Reset visit count for this node
self._node_visit_counts[graph_node] = 0
self.steps.append(step) self.steps.append(step)
return step return step
@property @property
def graph_chords(self) -> list[Chord]: def graph_chords(self) -> list[Chord]:
"""Get list of graph nodes (original chords).""" """Get list of destination graph nodes."""
return [self.initial_chord] + [step.graph_node for step in self.steps] return [self.initial_chord] + [step.destination_node for step in self.steps]
@property @property
def output_chords(self) -> list[Chord]: def output_chords(self) -> list[Chord]:
"""Get list of output chords (transposed).""" """Get list of destination chords (transposed)."""
return [self.initial_chord] + [step.output_chord for step in self.steps] return [self.initial_chord] + [step.destination_chord for step in self.steps]
def __len__(self) -> int: def __len__(self) -> int:
"""Total number of chords in path.""" """Total number of chords in path."""
@ -142,6 +174,38 @@ class Path:
"""Iterate over output chords.""" """Iterate over output chords."""
return iter(self.output_chords) return iter(self.output_chords)
def __getitem__(self, index: int) -> Chord: def get_influence(self, weights: dict[str, Any]) -> dict[str, float]:
"""Get output chord by index.""" """Compute weighted score contribution per factor for chosen candidates.
return self.output_chords[index]
Returns a dict mapping factor name to accumulated influence (weight * score)
for all steps in the path.
"""
influence = {
"melodic": 0.0,
"contrary_motion": 0.0,
"dca_hamiltonian": 0.0,
"dca_voice_movement": 0.0,
"target_range": 0.0,
}
for step in self.steps:
scores = step.scores
w_melodic = weights.get("weight_melodic", 1)
w_contrary = weights.get("weight_contrary_motion", 0)
w_hamiltonian = weights.get("weight_dca_hamiltonian", 1)
w_dca = weights.get("weight_dca_voice_movement", 1)
w_target = weights.get("weight_target_range", 1)
influence["melodic"] += scores.get("melodic_threshold", 0) * w_melodic
influence["contrary_motion"] += (
scores.get("contrary_motion", 0) * w_contrary
)
influence["dca_hamiltonian"] += (
scores.get("dca_hamiltonian", 0) * w_hamiltonian
)
influence["dca_voice_movement"] += (
scores.get("dca_voice_movement", 0) * w_dca
)
influence["target_range"] += scores.get("target_range", 0) * w_target
return influence