Fix uniform symdiff to fallback to other symdiff groups when no valid candidates
This commit is contained in:
parent
863423ca7a
commit
7dd7f23611
|
|
@ -68,6 +68,12 @@ class PathFinder:
|
|||
|
||||
chosen_symdiff = random.choice(list(edges_by_symdiff.keys()))
|
||||
out_edges = edges_by_symdiff[chosen_symdiff]
|
||||
# Use uniform weights for this symdiff group
|
||||
uniform_weight = 1.0
|
||||
out_edges_with_weight = [
|
||||
(e[0], e[1], {**e[2], "weight": uniform_weight}) for e in out_edges
|
||||
]
|
||||
out_edges = out_edges_with_weight
|
||||
|
||||
# Build candidates using Path's state and factor methods
|
||||
candidates = path_obj.get_candidates(out_edges, path_obj.output_chords)
|
||||
|
|
@ -77,6 +83,29 @@ class PathFinder:
|
|||
|
||||
# Filter out candidates with zero weight
|
||||
valid_candidates = [c for c in candidates if c.weight > 0]
|
||||
|
||||
# If uniform_symdiff and no valid candidates, try other symdiff groups
|
||||
if (
|
||||
not valid_candidates
|
||||
and weights_config.get("uniform_symdiff")
|
||||
and edges_by_symdiff
|
||||
):
|
||||
remaining_symdiffs = [
|
||||
s for s in edges_by_symdiff.keys() if s != chosen_symdiff
|
||||
]
|
||||
for fallback_symdiff in remaining_symdiffs:
|
||||
out_edges = edges_by_symdiff[fallback_symdiff]
|
||||
out_edges = [
|
||||
(e[0], e[1], {**e[2], "weight": 1.0}) for e in out_edges
|
||||
]
|
||||
candidates = path_obj.get_candidates(
|
||||
out_edges, path_obj.output_chords
|
||||
)
|
||||
path_obj.compute_weights(candidates, weights_config)
|
||||
valid_candidates = [c for c in candidates if c.weight > 0]
|
||||
if valid_candidates:
|
||||
break
|
||||
|
||||
if not valid_candidates:
|
||||
break
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue