compact_sets/src/path.py

77 lines
2.3 KiB
Python
Raw Normal View History

#!/usr/bin/env python
"""
Path and PathStep classes for storing path state from PathFinder.
"""
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any

from .chord import Chord
from .pitch import Pitch
@dataclass
class PathStep:
    """One transition of a path: the graph chord visited and what was emitted.

    Only ``graph_node`` and ``output_chord`` are required. Every other field
    is optional, and the container fields default to fresh, unshared empty
    objects.
    """

    # Chord as stored in the graph (the original, untransposed chord).
    graph_node: Chord
    # Chord actually emitted for this step (the transposed chord).
    output_chord: Chord
    # Optional Pitch -- presumably the transposition applied to graph_node to
    # obtain output_chord; confirm against the path finder.
    transposition: Pitch | None = field(default=None)
    # int -> int mapping; presumably per-voice movements -- TODO confirm.
    movements: dict[int, int] = field(default_factory=dict)
    # Named float scores recorded for this step.
    scores: dict[str, float] = field(default_factory=dict)
    # One name -> float mapping per entry; presumably the scores of the
    # alternative candidates considered at this step -- confirm.
    candidates: list[dict[str, float]] = field(default_factory=list)
class Path:
    """Complete state of a generated path (as produced by the path finder).

    A path is an initial chord followed by zero or more ``PathStep`` records.
    The sequence protocol (``len``, iteration, indexing) exposes the output
    chords, with the initial chord at index 0.
    """

    def __init__(
        self, initial_chord: Chord, weights_config: dict[str, Any] | None = None
    ):
        """Start a path at *initial_chord* with no steps.

        Args:
            initial_chord: First chord of the path. It is reported as both the
                first graph chord and the first output chord.
            weights_config: Optional configuration dict stored on the path
                (presumably the scoring weights used to build it -- confirm).
                Defaults to a fresh empty dict, so instances never share one.
        """
        self.initial_chord = initial_chord
        self.steps: list[PathStep] = []
        self.weights_config = weights_config if weights_config is not None else {}

    def add_step(
        self,
        graph_node: Chord,
        output_chord: Chord,
        transposition: Pitch | None = None,
        movements: dict[int, int] | None = None,
        scores: dict[str, float] | None = None,
        candidates: list[dict[str, float]] | None = None,
    ) -> None:
        """Append one step to the end of the path.

        Optional containers left as ``None`` are replaced with fresh empty
        ones. Containers that are passed in are stored as-is (not copied), so
        later mutation by the caller is visible through the step.
        """
        self.steps.append(
            PathStep(
                graph_node=graph_node,
                output_chord=output_chord,
                transposition=transposition,
                movements=movements if movements is not None else {},
                scores=scores if scores is not None else {},
                candidates=candidates if candidates is not None else [],
            )
        )

    @property
    def graph_chords(self) -> list[Chord]:
        """Graph nodes (original, untransposed chords), initial chord first.

        A new list is built on every access.
        """
        return [self.initial_chord] + [step.graph_node for step in self.steps]

    @property
    def output_chords(self) -> list[Chord]:
        """Output (transposed) chords, initial chord first.

        A new list is built on every access.
        """
        return [self.initial_chord] + [step.output_chord for step in self.steps]

    def __len__(self) -> int:
        """Total number of chords: the initial chord plus one per step."""
        return len(self.steps) + 1

    def __iter__(self) -> Iterator[Chord]:
        """Iterate over the output chords.

        The chords are snapshotted when iteration starts, so steps added
        during iteration are not visited.
        """
        return iter(self.output_chords)

    def __getitem__(self, index: int | slice) -> Chord | list[Chord]:
        """Return the output chord at *index*, with list-style indexing.

        Integer indices, including negative ones, are resolved in O(1)
        without building the full chord list. Slices return a new list of
        output chords. Like a list, this raises IndexError for an
        out-of-range index and TypeError for a non-integer, non-slice index.
        """
        if isinstance(index, slice):
            return self.output_chords[index]
        # Indexing a range applies list index semantics: negatives are
        # normalized, __index__ objects are accepted, and bad indices raise
        # IndexError/TypeError -- all without materializing the chords.
        position = range(len(self))[index]
        if position == 0:
            return self.initial_chord
        return self.steps[position - 1].output_chord