compact_sets/src/pathfinder.py

178 lines
6.1 KiB
Python
Raw Normal View History

#!/usr/bin/env python
"""
PathFinder - finds paths through voice leading graphs.
"""
from __future__ import annotations

import random
from random import choices
from typing import Callable

import networkx as nx

from .chord import Chord
from .path import Path, PathStep
class PathFinder:
    """Finds paths through voice leading graphs.

    The finder holds a graph whose nodes are chords and repeatedly asks a
    ``Path`` object to score the outgoing edges of the current node, then
    draws the next step at random in proportion to those scores.
    """

    # Upper bound on how many shuffled nodes are probed when looking for a
    # start node that has at least one positively weighted outgoing move.
    _MAX_START_PROBES = 50

    def __init__(self, graph: "nx.MultiDiGraph"):
        """Store the graph that path searches will walk."""
        self.graph = graph

    def find_stochastic_path(
        self,
        start_chord: "Chord | None" = None,
        max_length: int = 100,
        weights_config: "dict | None" = None,
        callback: "Callable[[int, Path, dict], None] | None" = None,
        interval: int = 1,
    ) -> "Path":
        """Find a stochastic path through the graph.

        Args:
            start_chord: Node to start from. When ``None`` a start node is
                chosen by ``_initialize_chords``.
            max_length: Maximum number of steps to take.
            weights_config: Configuration dict handed to ``Path``. When
                ``None`` the result of ``_default_weights_config()`` is used.
                A truthy ``"uniform_symdiff"`` entry switches edge selection
                to: group edges by their ``"symdiff"`` attribute, pick one
                group uniformly, and override edge weights in it to ``1.0``.
            callback: Called as ``callback(step_num, path, weights_config)``
                after every step whose number is a multiple of ``interval``.
            interval: Step spacing between callback invocations. Must be
                non-zero when ``callback`` is given.

        Returns:
            Path object containing output chords, graph chords, and metadata.
            If no start chord is available or the graph has no nodes, a
            ``Path`` built from the start chord (possibly ``None``) is
            returned without any steps.

        Raises:
            ValueError: If ``callback`` is given and ``interval`` is 0.
        """
        # Fail fast: otherwise ``step_num % interval`` would raise a bare
        # ZeroDivisionError after the first step had already been taken.
        if callback is not None and interval == 0:
            raise ValueError("interval must be non-zero when a callback is given")

        if weights_config is None:
            weights_config = self._default_weights_config()

        chord = self._initialize_chords(start_chord)
        if not chord or chord[0] is None or len(self.graph.nodes()) == 0:
            return Path(chord[0] if chord else None, weights_config)

        original_chord = chord[0]
        path_obj = Path(original_chord, weights_config)
        path_obj.init_state(
            set(self.graph.nodes()), len(original_chord.pitches), original_chord
        )

        graph_node = original_chord
        step_num = 0
        for _ in range(max_length):
            out_edges = list(self.graph.out_edges(graph_node, data=True))
            if not out_edges:
                break

            # Read the flag every iteration: the callback receives the same
            # dict and may change it between steps.
            if weights_config.get("uniform_symdiff"):
                valid_candidates = self._uniform_symdiff_candidates(
                    path_obj, out_edges, weights_config
                )
            else:
                valid_candidates = self._valid_candidates(
                    path_obj, out_edges, weights_config
                )
            if not valid_candidates:
                break

            # Select using weighted choice
            chosen = choices(
                valid_candidates, weights=[c.weight for c in valid_candidates]
            )[0]

            # Path.step() handles all voice-leading and state updates
            path_obj.step(chosen)
            graph_node = chosen.destination_node
            step_num += 1

            # Invoke callback if configured
            if callback is not None and step_num % interval == 0:
                callback(step_num, path_obj, weights_config)

        return path_obj

    @staticmethod
    def _group_edges_by_symdiff(out_edges: list) -> dict:
        """Group ``(u, v, data)`` edges by ``data["symdiff"]`` (default 0).

        Groups appear in the order their symdiff value is first seen, and
        edges keep their original relative order within each group.
        """
        groups: dict = {}
        for edge in out_edges:
            groups.setdefault(edge[2].get("symdiff", 0), []).append(edge)
        return groups

    @staticmethod
    def _valid_candidates(path_obj: "Path", out_edges: list, weights_config: dict) -> list:
        """Score ``out_edges`` via ``path_obj`` and keep positive-weight candidates."""
        candidates = path_obj.get_candidates(out_edges, path_obj.output_chords)
        path_obj.compute_weights(candidates, weights_config)
        return [c for c in candidates if c.weight > 0]

    def _uniform_symdiff_candidates(
        self, path_obj: "Path", out_edges: list, weights_config: dict
    ) -> list:
        """Return valid candidates from one symdiff group of ``out_edges``.

        A symdiff value is drawn uniformly; if its group yields no candidate
        with positive weight, the remaining groups are tried in first-seen
        order. Edge data is copied with ``"weight"`` overridden to ``1.0``.
        Returns an empty list when no group yields a valid candidate.
        """
        edges_by_symdiff = self._group_edges_by_symdiff(out_edges)
        chosen_symdiff = random.choice(list(edges_by_symdiff))
        fallbacks = [s for s in edges_by_symdiff if s != chosen_symdiff]

        for symdiff in [chosen_symdiff, *fallbacks]:
            flat_edges = [
                (e[0], e[1], {**e[2], "weight": 1.0})
                for e in edges_by_symdiff[symdiff]
            ]
            valid = self._valid_candidates(path_obj, flat_edges, weights_config)
            if valid:
                return valid
        return []

    def _initialize_chords(self, start_chord: "Chord | None") -> tuple:
        """Initialize chord sequence.

        Returns a one-element tuple holding the start node: ``start_chord``
        if given; otherwise the first of up to ``_MAX_START_PROBES`` randomly
        ordered nodes that has an outgoing candidate with positive weight
        under the default weights; otherwise the first shuffled node; or
        ``None`` when the graph has no nodes.
        """
        if start_chord is not None:
            return (start_chord,)

        nodes = list(self.graph.nodes())
        if not nodes:
            return (None,)

        random.shuffle(nodes)
        weights_config = self._default_weights_config()
        # Pinned explicitly so the probe does not depend on the default value.
        weights_config["voice_crossing_allowed"] = False

        for chord in nodes[: self._MAX_START_PROBES]:
            out_edges = list(self.graph.out_edges(chord, data=True))
            if not out_edges:
                continue

            path = Path(chord, weights_config)
            # A fresh node set per probe: Path may keep and modify it.
            path.init_state(set(self.graph.nodes()), len(chord.pitches), chord)
            candidates = path.get_candidates(out_edges, [chord])
            path.compute_weights(candidates, weights_config)
            if any(c.weight > 0 for c in candidates):
                return (chord,)

        # No probed node had a usable move; fall back to the first one.
        return (nodes[0],)

    def _default_weights_config(self) -> dict:
        """Default weights configuration (a new dict on every call)."""
        return {
            "contrary_motion": True,
            "direct_tuning": True,
            "voice_crossing_allowed": False,
            "melodic_threshold_min": 0,
            "melodic_threshold_max": 500,
            "hamiltonian": True,
            "dca": 2.0,
            "target_register": False,
            "target_register_octaves": 2.0,
        }

    def is_hamiltonian(self, path: list["Chord"]) -> bool:
        """Check if a path is Hamiltonian (visits all nodes exactly once).

        Only the length and the absence of duplicates are compared against
        the graph's node count; membership of each element is not checked.
        """
        return len(path) == len(self.graph.nodes()) and len(set(path)) == len(path)