#!/usr/bin/env python
"""
Path and PathStep classes for storing path state from PathFinder.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from .pitch import Pitch
from .chord import Chord

if TYPE_CHECKING:
    from .graph import Candidate


@dataclass
class PathStep:
    """Stores data for a single step in the path (as recorded by ``Path.step``)."""

    # Untransposed graph node visited at this step.
    graph_node: Chord
    # Graph node transposed by the cumulative transposition, in voice order.
    output_chord: Chord
    # Transposition taken from the edge data (None when the edge had none).
    transposition: Pitch | None = None
    # The edge's "movements" mapping exactly as supplied ({src_pos: dest_pos}).
    movements: dict[int, int] = field(default_factory=dict)
    # Scores of the chosen candidate.
    scores: dict[str, float] = field(default_factory=dict)
    # Candidates that were considered for this step.
    candidates: list["Candidate"] = field(default_factory=list)
    # Snapshot of steps-since-last-visit per node, taken before this node is reset.
    node_visit_counts: dict | None = None
    # Per-voice count of consecutive steps the voice kept the same pitch.
    voice_stay_count: tuple[int, ...] | None = None


class Path:
    """Stores the complete state of a generated path.

    The path starts at ``initial_chord``; every call to :meth:`step` appends a
    :class:`PathStep`. Output chords are voice-ordered: index ``v`` of each
    output chord holds the pitch of voice ``v``.
    """

    def __init__(
        self, initial_chord: Chord | None, weights_config: dict[str, Any] | None = None
    ):
        self.initial_chord = initial_chord
        self.steps: list[PathStep] = []
        self.weights_config = weights_config if weights_config is not None else {}

        # Tracking state; populated by init_state().
        # Steps since each graph node was last visited (reset to 0 on a visit).
        self._node_visit_counts: dict = {}
        # Consecutive steps each voice has held the same pitch (compared in cents).
        self._voice_stay_count: list[int] = []
        # _voice_map[pos] is the voice currently at graph-node position ``pos``.
        self._voice_map: list[int] = []
        # Cumulative transposition applied to every graph node.
        self._cumulative_trans: Pitch | None = None

    def init_state(
        self, graph_nodes: set, num_voices: int, initial_chord: Chord
    ) -> None:
        """Initialize tracking state once the graph is known.

        Args:
            graph_nodes: All graph nodes; each starts with a visit count of 0.
            num_voices: Number of voices; sizes the voice map and stay counters.
            initial_chord: Starting chord; its ``dims`` size the zero transposition.
        """
        self._node_visit_counts = {node: 0 for node in graph_nodes}
        self._node_visit_counts[initial_chord] = 0
        self._voice_stay_count = [0] * num_voices
        self._voice_map = list(range(num_voices))  # voice i starts at position i
        dims = initial_chord.dims
        self._cumulative_trans = Pitch((0,) * len(dims), dims)

    def step(
        self,
        graph_node: "Chord",
        edge_data: dict,
        candidates: list["Candidate"],
        chosen_scores: dict[str, float] | None = None,
    ) -> PathStep:
        """Process a step: update state, compute output, return step.

        Takes graph_node and edge_data, handles all voice-leading internally.

        Args:
            graph_node: Graph node (untransposed chord) being moved to.
            edge_data: Edge attributes. ``"transposition"`` (optional) is folded
                into the cumulative transposition; ``"movements"`` maps
                ``{src_position: dest_position}``. A missing or empty
                ``"movements"`` keeps every voice at its current position.
            candidates: Candidates considered for this step; stored on the step.
            chosen_scores: Scores of the chosen candidate; defaults to ``{}``.

        Returns:
            The newly appended :class:`PathStep`.

        Raises:
            RuntimeError: If :meth:`init_state` has not been called.
            ValueError: If ``"movements"`` leaves some position without a voice.
        """
        if self._cumulative_trans is None:
            raise RuntimeError("init_state() must be called before step()")

        # Get edge information
        trans = edge_data.get("transposition")
        movement = edge_data.get("movements", {})

        # Update cumulative transposition, then transpose the graph node by it.
        if trans is not None:
            self._cumulative_trans = self._cumulative_trans.transpose(trans)
        transposed = graph_node.transpose(self._cumulative_trans)

        # Track where each voice went, then emit the pitches in voice order.
        self._voice_map = self._advance_voice_map(movement)
        output_chord = Chord(self._voice_ordered_pitches(transposed), graph_node.dims)

        # Previous output chord, without rebuilding the whole output list.
        prev_output_chord = (
            self.steps[-1].output_chord if self.steps else self.initial_chord
        )

        # Increment all node visit counts
        for node in self._node_visit_counts:
            self._node_visit_counts[node] += 1

        self._update_voice_stay_counts(prev_output_chord, output_chord)

        # Create step with current state
        step = PathStep(
            graph_node=graph_node,
            output_chord=output_chord,
            transposition=trans,
            movements=movement,
            scores=chosen_scores if chosen_scores is not None else {},
            candidates=candidates,
            node_visit_counts=dict(self._node_visit_counts),
            voice_stay_count=tuple(self._voice_stay_count),
        )

        # Reset after the snapshot so the step records the pre-visit count.
        self._node_visit_counts[graph_node] = 0

        self.steps.append(step)
        return step

    def _advance_voice_map(self, movement: dict[int, int]) -> list[int]:
        """Return the voice map after moving position ``src`` to ``dest`` for each item."""
        if not movement:
            # No voice-leading information: every voice stays where it is.
            return list(self._voice_map)
        new_map: list[int | None] = [None] * len(self._voice_map)
        for src_pos, dest_pos in movement.items():
            new_map[dest_pos] = self._voice_map[src_pos]
        if any(voice is None for voice in new_map):
            raise ValueError(
                f"movements {movement!r} does not assign a voice to every position"
            )
        return new_map

    def _voice_ordered_pitches(self, chord: Chord) -> tuple:
        """Return ``chord``'s pitches reordered so index ``v`` holds voice ``v``'s pitch."""
        ordered: list = [None] * len(self._voice_map)
        for pos, voice in enumerate(self._voice_map):
            ordered[voice] = chord.pitches[pos]
        return tuple(ordered)

    def _update_voice_stay_counts(
        self, prev_chord: Chord | None, curr_chord: Chord
    ) -> None:
        """Bump each voice's stay count if its pitch (in cents) is unchanged, else reset it.

        Both chords are voice-ordered, so the same index is the same voice.
        """
        new_counts = []
        for voice, count in enumerate(self._voice_stay_count):
            stayed = (
                prev_chord is not None
                and prev_chord.pitches[voice].to_cents()
                == curr_chord.pitches[voice].to_cents()
            )
            new_counts.append(count + 1 if stayed else 0)
        self._voice_stay_count = new_counts

    @property
    def graph_chords(self) -> list[Chord]:
        """Get list of graph nodes (original chords)."""
        return [self.initial_chord] + [step.graph_node for step in self.steps]

    @property
    def output_chords(self) -> list[Chord]:
        """Get list of output chords (transposed, voice-ordered)."""
        return [self.initial_chord] + [step.output_chord for step in self.steps]

    def __len__(self) -> int:
        """Total number of chords in path (initial chord plus one per step)."""
        return len(self.steps) + 1

    def __iter__(self):
        """Iterate over output chords."""
        return iter(self.output_chords)

    def __getitem__(self, index: int) -> Chord:
        """Get output chord by index."""
        return self.output_chords[index]