mirror of
https://github.com/tgorordo/smithy.git
synced 2026-06-05 16:22:15 -07:00
update to SCC algorithm, tidy internal names
This commit is contained in:
parent
fcf2505820
commit
674fdc1fe9
7 changed files with 42 additions and 45 deletions
3
justfile
3
justfile
|
|
@ -7,9 +7,6 @@ run *args:
|
||||||
marimo:
|
marimo:
|
||||||
uv run marimo --edit
|
uv run marimo --edit
|
||||||
|
|
||||||
example:
|
|
||||||
uv run python src/main.py test/test_ballot.csv
|
|
||||||
|
|
||||||
format:
|
format:
|
||||||
uv run ruff format src test
|
uv run ruff format src test
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ authors = [
|
||||||
]
|
]
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy>=2.4.6",
|
|
||||||
"polars>=1.40.1",
|
"polars>=1.40.1",
|
||||||
"rustworkx>=0.17.1",
|
"rustworkx>=0.17.1",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -96,9 +96,7 @@ nbformat==5.10.4
|
||||||
nodeenv==1.10.0
|
nodeenv==1.10.0
|
||||||
# via pyright
|
# via pyright
|
||||||
numpy==2.4.6
|
numpy==2.4.6
|
||||||
# via
|
# via rustworkx
|
||||||
# smithy (pyproject.toml)
|
|
||||||
# rustworkx
|
|
||||||
openai==2.37.0
|
openai==2.37.0
|
||||||
# via pydantic-ai-slim
|
# via pydantic-ai-slim
|
||||||
opentelemetry-api==1.42.0
|
opentelemetry-api==1.42.0
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,17 @@
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
import rustworkx as rwx
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
|
|
||||||
from .rcv import pairmaj_from_rcv
|
from .rcv import pmg_from_rcv
|
||||||
|
|
||||||
|
|
||||||
def smith_set_brutefrom_pairmaj(pairmaj_graph: dict[str, set[str]]) -> list:
|
def ss_from_pmg(pmg: rwx.PyDiGraph) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Brute-force the Smith set from a pairwise majority winner graph.
|
Find the Smith set from a pairwise majority graph.
|
||||||
|
|
||||||
parameters
|
parameters
|
||||||
---
|
---
|
||||||
pairmaj_graph: dict[str, set[str]]
|
pmg: rwx.PyDiGraph
|
||||||
A graph whose nodes correspond to candidates and (directed) edges show
|
A graph whose nodes correspond to candidates and (directed) edges show
|
||||||
which candidates they beat pairwise.
|
which candidates they beat pairwise.
|
||||||
|
|
||||||
|
|
@ -22,28 +23,18 @@ def smith_set_brutefrom_pairmaj(pairmaj_graph: dict[str, set[str]]) -> list:
|
||||||
(single Majority winner), the Smith set will contain that single candidate.
|
(single Majority winner), the Smith set will contain that single candidate.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
candidates = set(pairmaj_graph.keys())
|
sccs = rwx.strongly_connected_components(pmg)
|
||||||
size = len(candidates)
|
|
||||||
|
|
||||||
for size in range(1, len(candidates) + 1):
|
cg = rwx.condensation(pmg, sccs)
|
||||||
for sub in combinations(candidates, size):
|
|
||||||
subset = set(sub)
|
|
||||||
out = set(candidates) - subset
|
|
||||||
|
|
||||||
dom = True
|
src_sccs = [nd for nd in cg.node_indices() if cg.in_degree(nd) == 0]
|
||||||
|
|
||||||
for member in subset:
|
smith_set = sorted([c for scc in src_sccs for c in cg[scc]])
|
||||||
if not out.issubset(pairmaj_graph[member]):
|
|
||||||
dom = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if dom:
|
return smith_set
|
||||||
return sorted(subset)
|
|
||||||
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
|
def smith_set_from_rcv(ballots: pl.DataFrame) -> list:
|
||||||
"""
|
"""
|
||||||
Compute the Smith set from a Ranked-Choice ballot.
|
Compute the Smith set from a Ranked-Choice ballot.
|
||||||
|
|
||||||
|
|
@ -52,7 +43,7 @@ def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
|
||||||
|
|
||||||
parameters
|
parameters
|
||||||
---
|
---
|
||||||
df : pl.DataFrame
|
ballots : pl.DataFrame
|
||||||
A Polars DataFrame representing ballots. Each column is a candidate and each
|
A Polars DataFrame representing ballots. Each column is a candidate and each
|
||||||
row is is a voter's ranking of the candidates. Lower numbers indicate higher
|
row is is a voter's ranking of the candidates. Lower numbers indicate higher
|
||||||
preference (1 = top-choice).
|
preference (1 = top-choice).
|
||||||
|
|
@ -66,7 +57,8 @@ def smith_set_from_rcv(rcv_ballots: pl.DataFrame) -> list:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return smith_set_brutefrom_pairmaj(pairmaj_from_rcv(rcv_ballots))
|
# return smith_set_brutefrom_pairmaj(pairmaj_from_rcv(rcv_ballots))
|
||||||
|
return ss_from_pmg(pmg_from_rcv(ballots))
|
||||||
|
|
||||||
|
|
||||||
def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list:
|
def smith_set(df: pl.DataFrame, ballotkind="rcv") -> list:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
import rustworkx as rwx
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
|
|
||||||
|
|
||||||
def pairmaj_from_rcv(rcv_ballots: pl.DataFrame) -> dict[str, set[str]]:
|
def pmg_from_rcv(ballots: pl.DataFrame) -> rwx.PyDiGraph:
|
||||||
"""
|
"""
|
||||||
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
Build a pairwise majority winner graph from a box of Ranked-Choice Ballots.
|
||||||
|
|
||||||
|
|
@ -15,27 +16,38 @@ def pairmaj_from_rcv(rcv_ballots: pl.DataFrame) -> dict[str, set[str]]:
|
||||||
|
|
||||||
returns
|
returns
|
||||||
---
|
---
|
||||||
pairmaj_graph: dict[str, set[str]]
|
nodes: dict[str, int]
|
||||||
|
A dictionary of candidate names to associated node ids.
|
||||||
|
|
||||||
|
pwm_graph: rwx.PyDiGraph
|
||||||
A pairwise majority winner graph whose nodes correspond to candidates and
|
A pairwise majority winner graph whose nodes correspond to candidates and
|
||||||
(directed) edges show which candidates they beat pairwise.
|
(directed) edges show which candidates they beat pairwise.
|
||||||
"""
|
"""
|
||||||
candidates = rcv_ballots.columns
|
candidates = ballots.columns
|
||||||
|
|
||||||
pairmaj_graph: dict[str, set[str]] = {c: set() for c in candidates}
|
pmg = rwx.PyDiGraph()
|
||||||
|
nodes = {c: pmg.add_node(c) for c in candidates}
|
||||||
|
|
||||||
for a, b in combinations(candidates, 2):
|
exprs = []
|
||||||
result = rcv_ballots.select(
|
pairs = list(combinations(candidates, 2))
|
||||||
|
|
||||||
|
for a, b in pairs:
|
||||||
|
exprs.extend(
|
||||||
[
|
[
|
||||||
(pl.col(a) < pl.col(b)).sum().alias("a_wins"),
|
(pl.col(a) < pl.col(b)).sum().alias(f"{a}>{b}"),
|
||||||
(pl.col(b) < pl.col(a)).sum().alias("b_wins"),
|
(pl.col(b) < pl.col(a)).sum().alias(f"{b}>{a}"),
|
||||||
]
|
]
|
||||||
).row(0)
|
)
|
||||||
|
|
||||||
a_wins, b_wins = result
|
results = ballots.select(exprs).row(0, named=True)
|
||||||
|
|
||||||
|
for a, b in pairs:
|
||||||
|
a_wins = results[f"{a}>{b}"]
|
||||||
|
b_wins = results[f"{b}>{a}"]
|
||||||
|
|
||||||
if a_wins > b_wins:
|
if a_wins > b_wins:
|
||||||
pairmaj_graph[a].add(b)
|
pmg.add_edge(nodes[a], nodes[b], a_wins - b_wins)
|
||||||
elif b_wins > a_wins:
|
elif b_wins > a_wins:
|
||||||
pairmaj_graph[b].add(a)
|
pmg.add_edge(nodes[b], nodes[a], b_wins - a_wins)
|
||||||
|
|
||||||
return pairmaj_graph
|
return pmg
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@
|
||||||
# dependencies = [
|
# dependencies = [
|
||||||
# "click>=8.4.1",
|
# "click>=8.4.1",
|
||||||
# "rich>=15.0.0",
|
# "rich>=15.0.0",
|
||||||
# "polars>=1.40.1"
|
# "polars>=1.40.1",
|
||||||
|
# "rustworkx>=0.17.1"
|
||||||
# ]
|
# ]
|
||||||
# ///
|
# ///
|
||||||
import sys, io
|
import sys, io
|
||||||
|
|
|
||||||
2
uv.lock
generated
2
uv.lock
generated
|
|
@ -1556,7 +1556,6 @@ name = "smithy"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "numpy" },
|
|
||||||
{ name = "polars" },
|
{ name = "polars" },
|
||||||
{ name = "rustworkx" },
|
{ name = "rustworkx" },
|
||||||
]
|
]
|
||||||
|
|
@ -1584,7 +1583,6 @@ dev = [
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "click", marker = "extra == 'cli'", specifier = ">=8.4.0" },
|
{ name = "click", marker = "extra == 'cli'", specifier = ">=8.4.0" },
|
||||||
{ name = "numpy", specifier = ">=2.4.6" },
|
|
||||||
{ name = "polars", specifier = ">=1.40.1" },
|
{ name = "polars", specifier = ">=1.40.1" },
|
||||||
{ name = "rich", marker = "extra == 'cli'", specifier = ">=15.0.0" },
|
{ name = "rich", marker = "extra == 'cli'", specifier = ">=15.0.0" },
|
||||||
{ name = "rustworkx", specifier = ">=0.17.1" },
|
{ name = "rustworkx", specifier = ">=0.17.1" },
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue