Fix uniform symdiff to fallback to other symdiff groups when no valid candidates
This commit is contained in:
parent
863423ca7a
commit
7dd7f23611
|
|
@ -68,6 +68,12 @@ class PathFinder:
|
||||||
|
|
||||||
chosen_symdiff = random.choice(list(edges_by_symdiff.keys()))
|
chosen_symdiff = random.choice(list(edges_by_symdiff.keys()))
|
||||||
out_edges = edges_by_symdiff[chosen_symdiff]
|
out_edges = edges_by_symdiff[chosen_symdiff]
|
||||||
|
# Use uniform weights for this symdiff group
|
||||||
|
uniform_weight = 1.0
|
||||||
|
out_edges_with_weight = [
|
||||||
|
(e[0], e[1], {**e[2], "weight": uniform_weight}) for e in out_edges
|
||||||
|
]
|
||||||
|
out_edges = out_edges_with_weight
|
||||||
|
|
||||||
# Build candidates using Path's state and factor methods
|
# Build candidates using Path's state and factor methods
|
||||||
candidates = path_obj.get_candidates(out_edges, path_obj.output_chords)
|
candidates = path_obj.get_candidates(out_edges, path_obj.output_chords)
|
||||||
|
|
@ -77,6 +83,29 @@ class PathFinder:
|
||||||
|
|
||||||
# Filter out candidates with zero weight
|
# Filter out candidates with zero weight
|
||||||
valid_candidates = [c for c in candidates if c.weight > 0]
|
valid_candidates = [c for c in candidates if c.weight > 0]
|
||||||
|
|
||||||
|
# If uniform_symdiff and no valid candidates, try other symdiff groups
|
||||||
|
if (
|
||||||
|
not valid_candidates
|
||||||
|
and weights_config.get("uniform_symdiff")
|
||||||
|
and edges_by_symdiff
|
||||||
|
):
|
||||||
|
remaining_symdiffs = [
|
||||||
|
s for s in edges_by_symdiff.keys() if s != chosen_symdiff
|
||||||
|
]
|
||||||
|
for fallback_symdiff in remaining_symdiffs:
|
||||||
|
out_edges = edges_by_symdiff[fallback_symdiff]
|
||||||
|
out_edges = [
|
||||||
|
(e[0], e[1], {**e[2], "weight": 1.0}) for e in out_edges
|
||||||
|
]
|
||||||
|
candidates = path_obj.get_candidates(
|
||||||
|
out_edges, path_obj.output_chords
|
||||||
|
)
|
||||||
|
path_obj.compute_weights(candidates, weights_config)
|
||||||
|
valid_candidates = [c for c in candidates if c.weight > 0]
|
||||||
|
if valid_candidates:
|
||||||
|
break
|
||||||
|
|
||||||
if not valid_candidates:
|
if not valid_candidates:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue