switch graph-building to numpy, add ballot caching

This commit is contained in:
Thomas (Tom) C. Gorordo 2026-05-25 07:31:26 -07:00
parent 86e880623f
commit 22765fe052
Signed by: tgorordo
GPG key ID: 0CBED22BB0D94490

View file

@ -3,7 +3,7 @@ import rustworkx as rwx
from itertools import combinations from itertools import combinations
def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph: def _pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph:
""" """
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots. Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
@ -51,3 +51,48 @@ def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph:
pmg.add_edge(nodes[b], nodes[a], b_wins - a_wins) pmg.add_edge(nodes[b], nodes[a], b_wins - a_wins)
return pmg return pmg
def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph:
    """
    Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.

    Identical ballots are first collapsed into a single row with a multiplicity,
    so the pairwise comparison runs once per distinct ranking rather than once
    per voter.

    parameters
    ---
    ballots : pl.DataFrame
        A Polars DataFrame representing ballots. Each column is a candidate and each
        row is a voter's ranking of the candidates. Lower numbers indicate higher
        preference (1 = top-choice).

    returns
    ---
    pmg : rwx.PyDiGraph
        A pairwise majority winner graph with one node per candidate (the node
        payload is the candidate name). A directed edge runs from the pairwise
        winner to the pairwise loser and is weighted by the winning margin in
        ballots; tied pairs get no edge.
    """
    candidates = ballots.columns
    pmg = rwx.PyDiGraph()
    nodes = {c: pmg.add_node(c) for c in candidates}
    if not candidates:
        # Nothing to compare; also avoids grouping by an empty key list.
        return pmg

    # Pick a name for the multiplicity column that cannot clash with a
    # candidate column (a candidate literally named "count" would otherwise
    # produce a duplicate column).
    count_col = "count"
    while count_col in nodes:
        count_col = "_" + count_col

    # One row per distinct ranking, plus how many voters cast it.
    compressed = ballots.group_by(candidates).agg(pl.len().alias(count_col))
    counts = compressed[count_col].to_numpy()
    arr = compressed.drop(count_col).to_numpy()

    # arr[:, :, None] < arr[:, None, :] has shape (ballot, i, j) and is True
    # where that ballot ranks candidate i strictly ahead of candidate j.
    # Weighting by the ballot multiplicity and summing over ballots gives
    # results[i, j] = number of voters preferring i to j.
    # NOTE(review): a candidate left unranked on a ballot presumably arrives as
    # null/NaN, in which case both comparisons are False and that ballot counts
    # for neither side of any pair involving them -- confirm this is intended.
    results = ((arr[:, :, None] < arr[:, None, :]) * counts[:, None, None]).sum(axis=0)

    for (i, a), (j, b) in combinations(enumerate(candidates), 2):
        # Convert to Python ints before subtracting so the margin can never
        # wrap around in unsigned array arithmetic.
        margin = int(results[i, j]) - int(results[j, i])
        if margin > 0:
            pmg.add_edge(nodes[a], nodes[b], margin)
        elif margin < 0:
            pmg.add_edge(nodes[b], nodes[a], -margin)
    return pmg