Refactor: Unify Candidate and PathStep, fix DCA Hamiltonian
- Remove Candidate class, use PathStep for both hypothetical and actual steps - Simplify Path.step() to accept a PathStep - Fix DCA Hamiltonian to return visit_count directly instead of normalized score - Tests pass and DCA properly discriminates
This commit is contained in:
parent
400f970858
commit
861d012a95
72
src/path.py
72
src/path.py
|
|
@ -5,14 +5,11 @@ Path and PathStep classes for storing path state from PathFinder.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import Any
|
||||||
|
|
||||||
from .pitch import Pitch
|
from .pitch import Pitch
|
||||||
from .chord import Chord
|
from .chord import Chord
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .pathfinder import Candidate
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PathStep:
|
class PathStep:
|
||||||
|
|
@ -25,7 +22,7 @@ class PathStep:
|
||||||
transposition: Pitch | None = None
|
transposition: Pitch | None = None
|
||||||
movements: dict[int, int] = field(default_factory=dict)
|
movements: dict[int, int] = field(default_factory=dict)
|
||||||
scores: dict[str, float] = field(default_factory=dict)
|
scores: dict[str, float] = field(default_factory=dict)
|
||||||
candidates: list["Candidate"] = field(default_factory=list)
|
weight: float = 0.0 # computed later by _compute_weights
|
||||||
last_visited_count_before: dict | None = None
|
last_visited_count_before: dict | None = None
|
||||||
last_visited_count_after: dict | None = None
|
last_visited_count_after: dict | None = None
|
||||||
sustain_count_before: tuple[int, ...] | None = None
|
sustain_count_before: tuple[int, ...] | None = None
|
||||||
|
|
@ -79,45 +76,24 @@ class Path:
|
||||||
# Initialize fresh: all voices start at 0
|
# Initialize fresh: all voices start at 0
|
||||||
return tuple(0 for _ in range(self._num_voices))
|
return tuple(0 for _ in range(self._num_voices))
|
||||||
|
|
||||||
def step(
|
def step(self, step: PathStep) -> PathStep:
|
||||||
self,
|
"""Add a completed step to the path.
|
||||||
edge: tuple,
|
|
||||||
candidates: list["Candidate"],
|
|
||||||
chosen_scores: dict[str, float] | None = None,
|
|
||||||
) -> PathStep:
|
|
||||||
"""Process a step: update state, compute output, return step.
|
|
||||||
|
|
||||||
Takes edge (source_node, destination_node, edge_data), handles all voice-leading internally.
|
Takes a PathStep (computed as a hypothetical step), updates internal state,
|
||||||
|
and adds it to the path.
|
||||||
"""
|
"""
|
||||||
source_node = edge[0]
|
|
||||||
destination_node = edge[1]
|
|
||||||
edge_data = edge[2]
|
|
||||||
|
|
||||||
trans = edge_data.get("transposition")
|
|
||||||
movement = edge_data.get("movements", {})
|
|
||||||
|
|
||||||
# Update cumulative transposition
|
# Update cumulative transposition
|
||||||
if trans is not None:
|
if step.transposition is not None:
|
||||||
self._cumulative_trans = self._cumulative_trans.transpose(trans)
|
self._cumulative_trans = self._cumulative_trans.transpose(
|
||||||
|
step.transposition
|
||||||
# Transpose the destination node
|
)
|
||||||
transposed = destination_node.transpose(self._cumulative_trans)
|
|
||||||
|
|
||||||
# Update voice map based on movement
|
# Update voice map based on movement
|
||||||
new_voice_map = [None] * len(self._voice_map)
|
new_voice_map = [None] * len(self._voice_map)
|
||||||
for src_idx, dest_idx in movement.items():
|
for src_idx, dest_idx in step.movements.items():
|
||||||
new_voice_map[dest_idx] = self._voice_map[src_idx]
|
new_voice_map[dest_idx] = self._voice_map[src_idx]
|
||||||
self._voice_map = new_voice_map
|
self._voice_map = new_voice_map
|
||||||
|
|
||||||
# Reorder pitches according to voice map
|
|
||||||
reordered_pitches = tuple(
|
|
||||||
transposed.pitches[self._voice_map[i]] for i in range(len(self._voice_map))
|
|
||||||
)
|
|
||||||
destination_chord = Chord(reordered_pitches, destination_node.dims)
|
|
||||||
|
|
||||||
# Get previous output chord
|
|
||||||
source_chord = self.output_chords[-1]
|
|
||||||
|
|
||||||
# Get BEFORE state from last step (or initialize fresh)
|
# Get BEFORE state from last step (or initialize fresh)
|
||||||
last_visited_before = self._get_last_visited_counts()
|
last_visited_before = self._get_last_visited_counts()
|
||||||
sustain_before = self._get_sustain_counts()
|
sustain_before = self._get_sustain_counts()
|
||||||
|
|
@ -126,32 +102,22 @@ class Path:
|
||||||
last_visited_after = dict(last_visited_before)
|
last_visited_after = dict(last_visited_before)
|
||||||
for node in last_visited_after:
|
for node in last_visited_after:
|
||||||
last_visited_after[node] += 1
|
last_visited_after[node] += 1
|
||||||
last_visited_after[destination_node] = 0
|
last_visited_after[step.destination_node] = 0
|
||||||
|
|
||||||
sustain_after = list(sustain_before)
|
sustain_after = list(sustain_before)
|
||||||
for voice_idx in range(len(sustain_after)):
|
for voice_idx in range(len(sustain_after)):
|
||||||
curr_cents = source_chord.pitches[voice_idx].to_cents()
|
curr_cents = step.source_chord.pitches[voice_idx].to_cents()
|
||||||
next_cents = destination_chord.pitches[voice_idx].to_cents()
|
next_cents = step.destination_chord.pitches[voice_idx].to_cents()
|
||||||
if curr_cents == next_cents:
|
if curr_cents == next_cents:
|
||||||
sustain_after[voice_idx] += 1
|
sustain_after[voice_idx] += 1
|
||||||
else:
|
else:
|
||||||
sustain_after[voice_idx] = 0
|
sustain_after[voice_idx] = 0
|
||||||
|
|
||||||
# Create step with before and after state
|
# Update step with computed state
|
||||||
step = PathStep(
|
step.last_visited_count_before = last_visited_before
|
||||||
source_node=source_node,
|
step.last_visited_count_after = last_visited_after
|
||||||
destination_node=destination_node,
|
step.sustain_count_before = sustain_before
|
||||||
source_chord=source_chord,
|
step.sustain_count_after = tuple(sustain_after)
|
||||||
destination_chord=destination_chord,
|
|
||||||
transposition=trans,
|
|
||||||
movements=movement,
|
|
||||||
scores=chosen_scores if chosen_scores is not None else {},
|
|
||||||
candidates=candidates,
|
|
||||||
last_visited_count_before=last_visited_before,
|
|
||||||
last_visited_count_after=last_visited_after,
|
|
||||||
sustain_count_before=sustain_before,
|
|
||||||
sustain_count_after=tuple(sustain_after),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.steps.append(step)
|
self.steps.append(step)
|
||||||
return step
|
return step
|
||||||
|
|
|
||||||
|
|
@ -4,23 +4,12 @@ PathFinder - finds paths through voice leading graphs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from dataclasses import dataclass
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from random import choices, seed
|
from random import choices, seed
|
||||||
from typing import Callable, Iterator
|
from typing import Callable
|
||||||
|
|
||||||
from .path import Path
|
from .chord import Chord
|
||||||
|
from .path import Path, PathStep
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Candidate:
|
|
||||||
"""A candidate edge with raw factor scores."""
|
|
||||||
|
|
||||||
edge: tuple
|
|
||||||
edge_index: int
|
|
||||||
graph_node: "Chord"
|
|
||||||
scores: dict[str, float]
|
|
||||||
weight: float = 0.0 # computed later by _compute_weights
|
|
||||||
|
|
||||||
|
|
||||||
class PathFinder:
|
class PathFinder:
|
||||||
|
|
@ -100,13 +89,9 @@ class PathFinder:
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# Use path.step() to handle all voice-leading and state updates
|
# Use path.step() to handle all voice-leading and state updates
|
||||||
path_obj.step(
|
path_obj.step(chosen)
|
||||||
edge=chosen.edge,
|
|
||||||
candidates=candidates,
|
|
||||||
chosen_scores=chosen.scores,
|
|
||||||
)
|
|
||||||
|
|
||||||
graph_node = chosen.graph_node
|
graph_node = chosen.destination_node
|
||||||
step_num += 1
|
step_num += 1
|
||||||
|
|
||||||
# Invoke callback if configured
|
# Invoke callback if configured
|
||||||
|
|
@ -124,15 +109,43 @@ class PathFinder:
|
||||||
graph_path: list["Chord"] | None,
|
graph_path: list["Chord"] | None,
|
||||||
cumulative_trans: "Pitch | None",
|
cumulative_trans: "Pitch | None",
|
||||||
node_visit_counts: dict | None,
|
node_visit_counts: dict | None,
|
||||||
) -> list["Candidate"]:
|
) -> list[PathStep]:
|
||||||
"""Build candidates with raw factor scores only."""
|
"""Build hypothetical path steps with raw factor scores."""
|
||||||
if not out_edges:
|
if not out_edges:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
if not path:
|
||||||
|
return []
|
||||||
|
|
||||||
|
source_chord = path[-1]
|
||||||
candidates = []
|
candidates = []
|
||||||
for i, edge in enumerate(out_edges):
|
for i, edge in enumerate(out_edges):
|
||||||
|
source_node = edge[0]
|
||||||
|
destination_node = edge[1]
|
||||||
edge_data = edge[2]
|
edge_data = edge[2]
|
||||||
|
|
||||||
|
trans = edge_data.get("transposition")
|
||||||
|
movement = edge_data.get("movements", {})
|
||||||
|
|
||||||
|
# Transpose destination node
|
||||||
|
if trans is not None and cumulative_trans is not None:
|
||||||
|
transposed = destination_node.transpose(
|
||||||
|
cumulative_trans.transpose(trans)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
transposed = destination_node
|
||||||
|
|
||||||
|
# Apply voice map
|
||||||
|
voice_map = list(range(len(source_chord.pitches)))
|
||||||
|
new_voice_map = [None] * len(voice_map)
|
||||||
|
for src_idx, dest_idx in movement.items():
|
||||||
|
new_voice_map[dest_idx] = voice_map[src_idx]
|
||||||
|
|
||||||
|
reordered_pitches = tuple(
|
||||||
|
transposed.pitches[new_voice_map[i]] for i in range(len(new_voice_map))
|
||||||
|
)
|
||||||
|
destination_chord = Chord(reordered_pitches, destination_node.dims)
|
||||||
|
|
||||||
# All factors - always compute verbatim
|
# All factors - always compute verbatim
|
||||||
direct_tuning = self._factor_direct_tuning(edge_data, config)
|
direct_tuning = self._factor_direct_tuning(edge_data, config)
|
||||||
voice_crossing = self._factor_voice_crossing(edge_data, config)
|
voice_crossing = self._factor_voice_crossing(edge_data, config)
|
||||||
|
|
@ -154,18 +167,27 @@ class PathFinder:
|
||||||
"target_range": target,
|
"target_range": target,
|
||||||
}
|
}
|
||||||
|
|
||||||
candidates.append(Candidate(edge, i, edge[1], scores, 0.0))
|
step = PathStep(
|
||||||
|
source_node=source_node,
|
||||||
|
destination_node=destination_node,
|
||||||
|
source_chord=source_chord,
|
||||||
|
destination_chord=destination_chord,
|
||||||
|
transposition=trans,
|
||||||
|
movements=movement,
|
||||||
|
scores=scores,
|
||||||
|
)
|
||||||
|
candidates.append(step)
|
||||||
|
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
def _compute_weights(
|
def _compute_weights(
|
||||||
self,
|
self,
|
||||||
candidates: list["Candidate"],
|
candidates: list[PathStep],
|
||||||
config: dict,
|
config: dict,
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
"""Compute weights from raw scores for all candidates.
|
"""Compute weights from raw scores for all candidates.
|
||||||
|
|
||||||
Returns a list of weights, and updates each candidate's weight field.
|
Returns a list of weights, and updates each step's weight field.
|
||||||
"""
|
"""
|
||||||
if not candidates:
|
if not candidates:
|
||||||
return []
|
return []
|
||||||
|
|
@ -194,20 +216,20 @@ class PathFinder:
|
||||||
|
|
||||||
# Calculate weights for each candidate
|
# Calculate weights for each candidate
|
||||||
weights = []
|
weights = []
|
||||||
for i, candidate in enumerate(candidates):
|
for i, step in enumerate(candidates):
|
||||||
scores = candidate.scores
|
scores = step.scores
|
||||||
w = 1.0
|
w = 1.0
|
||||||
|
|
||||||
# Hard factors (multiplicative - eliminates if 0)
|
# Hard factors (multiplicative - eliminates if 0)
|
||||||
w *= scores.get("direct_tuning", 0)
|
w *= scores.get("direct_tuning", 0)
|
||||||
if w == 0:
|
if w == 0:
|
||||||
candidate.weight = 0.0
|
step.weight = 0.0
|
||||||
weights.append(0.0)
|
weights.append(0.0)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
w *= scores.get("voice_crossing", 0)
|
w *= scores.get("voice_crossing", 0)
|
||||||
if w == 0:
|
if w == 0:
|
||||||
candidate.weight = 0.0
|
step.weight = 0.0
|
||||||
weights.append(0.0)
|
weights.append(0.0)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -223,7 +245,7 @@ class PathFinder:
|
||||||
if target_norm:
|
if target_norm:
|
||||||
w += target_norm[i] * config.get("weight_target_range", 1)
|
w += target_norm[i] * config.get("weight_target_range", 1)
|
||||||
|
|
||||||
candidate.weight = w
|
step.weight = w
|
||||||
weights.append(w)
|
weights.append(w)
|
||||||
|
|
||||||
return weights
|
return weights
|
||||||
|
|
@ -353,29 +375,21 @@ class PathFinder:
|
||||||
def _factor_dca_hamiltonian(
|
def _factor_dca_hamiltonian(
|
||||||
self, edge: tuple, node_visit_counts: dict | None, config: dict
|
self, edge: tuple, node_visit_counts: dict | None, config: dict
|
||||||
) -> float:
|
) -> float:
|
||||||
"""Returns probability based on how long since node was last visited.
|
"""Returns score based on how long since node was last visited.
|
||||||
|
|
||||||
DCA Hamiltonian: longer since visited = higher probability.
|
DCA Hamiltonian: longer since visited = higher score.
|
||||||
Similar to DCA voice movement but for graph nodes.
|
|
||||||
"""
|
"""
|
||||||
if config.get("weight_dca_hamiltonian", 1) == 0:
|
if config.get("weight_dca_hamiltonian", 1) == 0:
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
if node_visit_counts is None:
|
if node_visit_counts is None:
|
||||||
return 1.0
|
return 0.0
|
||||||
|
|
||||||
destination = edge[1]
|
destination = edge[1]
|
||||||
if destination not in node_visit_counts:
|
visit_count = node_visit_counts.get(destination, 0)
|
||||||
return 1.0
|
|
||||||
|
|
||||||
visit_count = node_visit_counts[destination]
|
# Return the visit count - higher is better (more steps since last visit)
|
||||||
max_count = max(node_visit_counts.values()) if node_visit_counts else 0
|
return float(visit_count)
|
||||||
|
|
||||||
if max_count == 0:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Normalize by max squared - gives stronger discrimination
|
|
||||||
return visit_count / (max_count**2)
|
|
||||||
|
|
||||||
def _factor_dca_voice_movement(
|
def _factor_dca_voice_movement(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue