# extrai/core/conflict_resolvers.py
from collections import Counter
from collections.abc import Callable
from difflib import SequenceMatcher
from typing import Any
from extrai.utils.flattening_utils import JSONValue, Path
# Signature shared by the conflict resolvers in this module.
# A resolver is called as ``resolver(path, values, weights)``:
#   * ``path``    - the Path of the field whose values conflict,
#   * ``values``  - the candidate values for that field,
#   * ``weights`` - None, or a list of per-value weights (the resolvers below
#                   pair it index-wise with ``values``).
# It returns the value to keep, or None (the default resolver returns None so
# that a field without consensus is omitted).
ConflictResolutionStrategy = Callable[
[Path, list[JSONValue], list[float] | None], JSONValue | None
]
def default_conflict_resolver(
path: Path, values: list[JSONValue], weights: list[float] | None = None
) -> JSONValue | None:
"""
Default conflict resolution: if no consensus, omit the field.
"""
return None
def prefer_most_common_resolver(
_path: Path, values: list[JSONValue], weights: list[float] | None = None
) -> JSONValue | None:
"""
Conflict resolution: prefer the most common value.
If weights are provided, prefers the value with the highest total weight.
"""
if not values:
return None
if weights and len(weights) == len(values):
# Weighted voting
weighted_counts: dict[Any, float] = {}
# We need to handle unhashable types (like dicts/lists) if they appear in values
# But JSONValue can be complex. Typically conflict resolution is on leaves (primitives).
# Flattening utils usually produce primitives at leaves, but lists can be values if not recursed?
# Assuming primitives for now (str, int, float, bool, None).
for val, w in zip(values, weights):
# If val is unhashable, we can't key it easily.
# Fallback to string repr or identity if needed, but for now assume hashable.
try:
weighted_counts[val] = weighted_counts.get(val, 0.0) + w
except TypeError:
# Unhashable type (e.g. list), skip optimization or use repr
# For safety, let's just pick the first one if we can't count.
# Or convert to tuple?
# Let's rely on standard Counter behavior for fallback.
pass
if weighted_counts:
# Pick value with max weight
# Break ties by first occurrence (insertion order in weighted_counts)
most_common_value = max(weighted_counts, key=weighted_counts.get)
return most_common_value
# Fallback to unweighted count
# Note: Counter works with unhashable types? No.
# If values contains unhashables, Counter(values) raises TypeError.
# We should handle that, but original code assumed they work or didn't handle lists as values?
# flattening_utils unflattening implies values are leaves.
try:
count = Counter(values)
most_common_value, _ = count.most_common(1)[0]
return most_common_value
except TypeError:
# Fallback for unhashable
return values[0]
def levenshtein_similarity(a: str, b: str) -> float:
    """Score how alike two strings are.

    The score is ``difflib.SequenceMatcher(None, a, b).ratio()``: 1.0 for
    identical strings (including two empty strings) and 0.0 when the strings
    share no matching characters.

    NOTE: despite the function name, the value comes from difflib's matching
    ratio, not from a Levenshtein edit-distance computation.
    """
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()
class SimilarityClusterResolver:
    """
    Resolves conflicts by clustering values based on string similarity.
    Useful for filtering out outliers (e.g. "War" vs "Christmas", "Gifts").

    Two values are linked when ``scorer(a, b) >= similarity_threshold``;
    clusters are the connected components of those links, so similarity is
    transitive within a cluster. The best cluster is the one with the highest
    total weight (when usable weights are given) or otherwise the largest one,
    and its representative is chosen by ``prefer_most_common_resolver``.

    Args:
        similarity_threshold: Minimum score for two strings to be linked.
        scorer: Function returning a similarity score for two strings.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.6,
        scorer: Callable[[str, str], float] = levenshtein_similarity,
    ):
        self.similarity_threshold = similarity_threshold
        self.scorer = scorer

    def __call__(
        self, path: Path, values: list[JSONValue], weights: list[float] | None = None
    ) -> JSONValue | None:
        """
        Pick a representative value from the best similarity cluster.

        Args:
            path: Path of the conflicting field; forwarded to
                ``prefer_most_common_resolver``.
            values: Candidate values.
            weights: Optional per-value weights. Used only when non-empty and
                of the same length as ``values``; otherwise ignored.

        Returns:
            The resolved value, or ``None`` when ``values`` is empty.
        """
        if not values:
            return None

        # Clustering is only defined for strings; anything else is a plain vote.
        if not all(isinstance(v, str) for v in values):
            return prefer_most_common_resolver(path, values, weights)

        # Validate the weights once and use the same decision everywhere, so a
        # list of the wrong length can neither raise IndexError nor be applied
        # misaligned to the chosen cluster.
        usable_weights = weights if weights and len(weights) == len(values) else None

        # ``values`` is non-empty here, so there is always at least one cluster.
        clusters = self._cluster_indices(values)

        if usable_weights is not None:
            best_cluster = max(
                clusters, key=lambda members: sum(usable_weights[i] for i in members)
            )
        else:
            best_cluster = max(clusters, key=len)

        cluster_values = [values[i] for i in best_cluster]
        cluster_weights = (
            [usable_weights[i] for i in best_cluster]
            if usable_weights is not None
            else None
        )
        return prefer_most_common_resolver(path, cluster_values, cluster_weights)

    def _cluster_indices(self, values: list[str]) -> list[list[int]]:
        """Group the indices of ``values`` into connected similarity clusters."""
        n = len(values)

        # 1. Pairwise similarities -> undirected adjacency lists.
        neighbours: list[list[int]] = [[] for _ in range(n)]
        for i in range(n):
            for j in range(i + 1, n):
                if self.scorer(values[i], values[j]) >= self.similarity_threshold:
                    neighbours[i].append(j)
                    neighbours[j].append(i)

        # 2. Connected components via iterative depth-first search.
        visited: set[int] = set()
        clusters: list[list[int]] = []
        for start in range(n):
            if start in visited:
                continue
            visited.add(start)
            component: list[int] = []
            stack = [start]
            while stack:
                node = stack.pop()
                component.append(node)
                for neighbour in neighbours[node]:
                    if neighbour not in visited:
                        visited.add(neighbour)
                        stack.append(neighbour)
            # Keep members in input order so that later first-occurrence
            # tie-breaking follows the order of ``values``, not DFS order.
            component.sort()
            clusters.append(component)
        return clusters