diff --git a/justfile b/justfile index a6f0786..b2dcf7d 100644 --- a/justfile +++ b/justfile @@ -7,9 +7,6 @@ run *args: marimo: uv run marimo --edit -example: - uv run python src/main.py test/test_ballot.csv - format: uv run ruff format src test diff --git a/pyproject.toml b/pyproject.toml index 9ffd889..3aff00a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ authors = [ ] requires-python = ">=3.13" dependencies = [ - "numpy>=2.4.6", "polars>=1.40.1", "rustworkx>=0.17.1", ] diff --git a/requirements.txt b/requirements.txt index 7b23b9d..675f3cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -96,9 +96,7 @@ nbformat==5.10.4 nodeenv==1.10.0 # via pyright numpy==2.4.6 - # via - # smithy (pyproject.toml) - # rustworkx + # via rustworkx openai==2.37.0 # via pydantic-ai-slim opentelemetry-api==1.42.0 diff --git a/src/smithy/__init__.py b/src/smithy/__init__.py index ed1ba6f..96cb85d 100644 --- a/src/smithy/__init__.py +++ b/src/smithy/__init__.py @@ -1,16 +1,17 @@ import polars as pl +import rustworkx as rwx from itertools import combinations -from .rcv import pairmaj_from_rcv +from .rcv import pmg_from_rcv -def smith_set_brutefrom_pairmaj(pairmaj_graph: dict[str, set[str]]) -> list: +def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]: """ - Brute-force the Smith set from a pairwise majority winner graph. + Find the Smith set from a pairwise majority graph. parameters --- - pairmaj_graph: dict[str, set[str]] + pmg: rwx.PyDiGraph A graph whose nodes correspond to candidates and (directed) edges show which candidates they beat pairwise. @@ -22,28 +23,18 @@ def smith_set_brutefrom_pairmaj(pairmaj_graph: dict[str, set[str]]) -> list: (single Majority winner), the Smith set will contain that single candidate. 
""" - candidates = set(pairmaj_graph.keys()) - size = len(candidates) + sccs = rwx.strongly_connected_components(pmg) - for size in range(1, len(candidates) + 1): - for sub in combinations(candidates, size): - subset = set(sub) - out = set(candidates) - subset + cg = rwx.condensation(pmg, sccs) - dom = True + src_sccs = [nd for nd in cg.node_indices() if cg.in_degree(nd) == 0] - for member in subset: - if not out.issubset(pairmaj_graph[member]): - dom = False - break + smith_set = sorted([c for scc in src_sccs for c in cg[scc]]) - if dom: - return sorted(subset) - - return [] + return smith_set -def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list: +def smith_set_from_rcv(ballots: pl.DataFrame) -> list: """ Compute the Smith set from a Ranked-Choice ballot. @@ -52,7 +43,7 @@ def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list: parameters --- - df : pl.DataFrame + ballots : pl.DataFrame A Polars DataFrame representing ballots. Each column is a candidate and each row is is a voter's ranking of the candidates. Lower numbers indicate higher preference (1 = top-choice). @@ -66,7 +57,8 @@ def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list: """ - return smith_set_brutefrom_pairmaj(pairmaj_from_rcv(rcv_ballots)) + # return smith_set_brutefrom_pairmaj(pairmaj_from_rcv(rcv_ballots)) + return ss_from_pmg(pmg_from_rcv(ballots)) def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list: diff --git a/src/smithy/rcv.py b/src/smithy/rcv.py index 94952e1..4c99c74 100644 --- a/src/smithy/rcv.py +++ b/src/smithy/rcv.py @@ -1,8 +1,9 @@ import polars as pl +import rustworkx as rwx from itertools import combinations -def pairmaj_from_rcv(rcv_ballots: pl.DataFrame) -> dict[str, set[str]]: +def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph: """ Build a pairwise majority winner graph from a box of Ranked-Choice Ballots. 
@@ -15,27 +16,38 @@ def pairmaj_from_rcv(rcv_ballots: pl.DataFrame) -> dict[str, set[str]]: returns --- - pairmaj_graph: dict[str, set[str]] + pmg: rwx.PyDiGraph + Node payloads are the candidate names; each edge weight is the + winner's margin (winner's pairwise count minus the loser's). + A pairwise majority winner graph whose nodes correspond to candidates and (directed) edges show which candidates they beat pairwise. """ - candidates = rcv_ballots.columns + candidates = ballots.columns - pairmaj_graph: dict[str, set[str]] = {c: set() for c in candidates} + pmg = rwx.PyDiGraph() + nodes = {c: pmg.add_node(c) for c in candidates} - for a, b in combinations(candidates, 2): - result = rcv_ballots.select( + exprs = [] + pairs = list(combinations(candidates, 2)) + + for a, b in pairs: + exprs.extend( [ - (pl.col(a) < pl.col(b)).sum().alias("a_wins"), - (pl.col(b) < pl.col(a)).sum().alias("b_wins"), + (pl.col(a) < pl.col(b)).sum().alias(f"{a}>{b}"), + (pl.col(b) < pl.col(a)).sum().alias(f"{b}>{a}"), ] - ).row(0) + ) - a_wins, b_wins = result + results = ballots.select(exprs).row(0, named=True) + + for a, b in pairs: + a_wins = results[f"{a}>{b}"] + b_wins = results[f"{b}>{a}"] if a_wins > b_wins: - pairmaj_graph[a].add(b) + pmg.add_edge(nodes[a], nodes[b], a_wins - b_wins) elif b_wins > a_wins: - pairmaj_graph[b].add(a) + pmg.add_edge(nodes[b], nodes[a], b_wins - a_wins) - return pairmaj_graph + return pmg diff --git a/src/smithycmd.py b/src/smithycmd.py index dcd3840..3b58971 100644 --- a/src/smithycmd.py +++ b/src/smithycmd.py @@ -3,7 +3,8 @@ # dependencies = [ # "click>=8.4.1", # "rich>=15.0.0", -# "polars>=1.40.1" +# "polars>=1.40.1", +# "rustworkx>=0.17.1" # ] # /// import sys, io diff --git a/uv.lock b/uv.lock index 149db5e..416c4ae 100644 --- a/uv.lock +++ b/uv.lock @@ -1556,7 +1556,6 @@ name = "smithy" version = "0.1.0" source = { editable = "."
} dependencies = [ - { name = "numpy" }, { name = "polars" }, { name = "rustworkx" }, ] @@ -1584,7 +1583,6 @@ dev = [ [package.metadata] requires-dist = [ { name = "click", marker = "extra == 'cli'", specifier = ">=8.4.0" }, - { name = "numpy", specifier = ">=2.4.6" }, { name = "polars", specifier = ">=1.40.1" }, { name = "rich", marker = "extra == 'cli'", specifier = ">=15.0.0" }, { name = "rustworkx", specifier = ">=0.17.1" },