diff --git a/src/smithy/rcv.py b/src/smithy/rcv.py index 07c19dc..1f5a07a 100644 --- a/src/smithy/rcv.py +++ b/src/smithy/rcv.py @@ -3,7 +3,7 @@ import rustworkx as rwx from itertools import combinations -def _pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph: +def pmg_from_rcv_polars(ballots: pl.DataFrame) -> rwx.PyDiGraph: """ Build a pairwise majority winner graph from a box of Ranked-Choice Ballots. @@ -54,7 +54,7 @@ def _pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph: return pmg -def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph: +def pmg_from_rcv_numpy(ballots: pl.DataFrame) -> rwx.PyDiGraph: """ Build a pairwise majority winner graph from a box of Ranked-Choice Ballots. @@ -97,4 +97,12 @@ def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph: elif b_wins > a_wins: pmg.add_edge(nodes[b], nodes[a], int(b_wins - a_wins)) - return pmg \ No newline at end of file + return pmg + +def pmg_from_rcv(ballots: pl.DataFrame, method="polars") -> rwx.PyDiGraph: + if method == "polars": + return pmg_from_rcv_polars(ballots) + elif method == "numpy": + return pmg_from_rcv_numpy(ballots) + else: + raise NotImplementedError(f"`pmg_from_rcv` method={method} not implemented.") \ No newline at end of file