1
0

feat: update legacy

- use sorted lut and bisect to optimize lut resolver
- add circuit decuper
- add signed diff for response item
This commit is contained in:
2026-06-17 19:49:54 +08:00
parent 96fa6263a8
commit d42885f1ab
2 changed files with 196 additions and 149 deletions

View File

@@ -1,8 +1,9 @@
import enum
import struct
from functools import cached_property
from dataclasses import dataclass
from typing import Iterator
from .common import DeviceKind, Circuit
from typing import Iterable, Iterator
from .common import DeviceKind, Circuit, CircuitDeviceScale
class ResponsePriority(enum.Enum):
@@ -34,24 +35,61 @@ class Request:
"""The limited count of results."""
class ResponseItem:
class ResponseDeduperItem:
"""
The possible solution given by the resolver.
The item for response deduplicator.
"""
__circuit: Circuit
"""The circuit of the response item."""
__device_kind: DeviceKind
"""The kind of device of this circuit."""
__target_value: float
"""The target value of this circuit."""
"""The circuit of this deduplicator item."""
def __init__(
self, circuit: Circuit, device_kind: DeviceKind, target_value: float
) -> None:
def __init__(self, circuit: Circuit) -> None:
self.__circuit = circuit
self.__device_kind = device_kind
self.__target_value = target_value
__ONE_PACKER = struct.Struct("=id")
__TWO_PACKER = struct.Struct("=iidd")
__THREE_PACKER = struct.Struct("=iiiddd")
@cached_property
def __uniform_circuit_presentation(self) -> bytes:
c = self.__circuit
match c.device_scale:
case CircuitDeviceScale.ONE:
return self.__ONE_PACKER.pack(1, c.first_device_value)
case CircuitDeviceScale.TWO:
v1, v2 = sorted([c.first_device_value, c.second_device_value])
return self.__TWO_PACKER.pack(2, int(c.second_device_joint), v1, v2)
case CircuitDeviceScale.THREE:
v1, v2, v3 = (
c.first_device_value,
c.second_device_value,
c.third_device_value,
)
j2, j3 = int(c.second_device_joint), int(c.third_device_joint)
if j2 == j3:
v1, v2, v3 = sorted([v1, v2, v3])
else:
v1, v2 = sorted([v1, v2])
return self.__THREE_PACKER.pack(3, j2, j3, v1, v2, v3)
@cached_property
def __uniform_circuit_hash(self) -> int:
return hash(self.__uniform_circuit_presentation)
def __eq__(self, other: object) -> bool:
if not isinstance(other, ResponseDeduperItem):
return False
return (
self.__uniform_circuit_presentation == other.__uniform_circuit_presentation
)
def __hash__(self) -> int:
return self.__uniform_circuit_hash
@property
def circuit(self) -> Circuit:
@@ -62,7 +100,59 @@ class ResponseItem:
"""
return self.__circuit
@cached_property
class ResponseDeduper:
"""
The deduplicator for response circuits to deduplicate equivalent circuits.
"""
__circuits: set[ResponseDeduperItem]
def __init__(self) -> None:
self.__circuits = set()
def add(self, circuit: Circuit) -> None:
self.__circuits.add(ResponseDeduperItem(circuit))
def __len__(self) -> int:
return len(self.__circuits)
def __iter__(self) -> Iterator[Circuit]:
return map(lambda x: x.circuit, self.__circuits)
class ResponseItem:
"""
The possible solution given by the resolver.
"""
__circuit: Circuit
"""The circuit of the response item."""
__value: float
"""The value of the response circuit."""
__difference: float
"""The signed difference between the target value and the value of this circuit."""
__relative_difference: float
"""The signed relative difference between the target value and the value of this circuit."""
def __init__(
self, circuit: Circuit, device_kind: DeviceKind, target_value: float
) -> None:
self.__circuit = circuit
self.__value = self.__circuit.compute(device_kind)
self.__difference = self.__value - target_value
self.__relative_difference = self.__difference / target_value
@property
def circuit(self) -> Circuit:
"""
The circuit of this response item.
:return: The circuit.
"""
return self.__circuit
@property
def device_count(self) -> int:
"""
The device count of this circuit.
@@ -71,32 +161,56 @@ class ResponseItem:
"""
return self.__circuit.device_scale.to_device_count()
@cached_property
@property
def value(self) -> float:
"""
The value of this circuit.
:return: The value.
"""
return self.__circuit.compute(self.__device_kind)
return self.__value
@cached_property
@property
def difference(self) -> float:
"""
The absolute difference between the target value and the value of this circuit.
The signed difference between the target value and the value of this circuit.
:return: The absolute difference.
Positive value indicates that the value of this circuit is greater than the target value.
Negative value indicates that the value of this circuit is less than the target value.
:return: The signed difference.
"""
return abs(self.__target_value - self.value)
return self.__difference
@cached_property
@property
def unsigned_difference(self) -> float:
"""
The unsigned difference between the target value and the value of this circuit.
:return: The unsigned difference.
"""
return abs(self.__difference)
@property
def relative_difference(self) -> float:
"""
The relative difference between the target value and the value of this circuit.
The signed relative difference between the target value and the value of this circuit.
:return: The relative difference.
Positive value indicates that the value of this circuit is greater than the target value.
Negative value indicates that the value of this circuit is less than the target value.
:return: The signed relative difference.
"""
return self.difference / self.__target_value
return self.__relative_difference
@property
def unsigned_relative_difference(self) -> float:
"""
The unsigned relative difference between the target value and the value of this circuit.
:return: The unsigned relative difference.
"""
return abs(self.__relative_difference)
class Response:
@@ -111,7 +225,7 @@ class Response:
__sorted_items: list[ResponseItem]
"""The sorted items by priority and difference."""
def __init__(self, request: Request, candidates: Iterator[Circuit]) -> None:
def __init__(self, request: Request, candidates: Iterable[Circuit]) -> None:
self.__sorted_items = list(
ResponseItem(item, request.device_kind, request.target_value)
for item in candidates
@@ -125,7 +239,7 @@ class Response:
self.__sorted_items.sort(key=lambda x: x.difference)
# Cut item by limit
self.__sorted_items = self.__sorted_items[:request.count_limit]
self.__sorted_items = self.__sorted_items[: request.count_limit]
def __getitem__(self, index: int) -> ResponseItem:
return self.__sorted_items[index]

View File

@@ -1,11 +1,10 @@
import heapq
import bisect
from itertools import chain, product
from typing import Iterable, Iterator
from functools import cached_property
from .common import Resolver
from ..dataset import DatasetCollection, Dataset
from ..common import Circuit, DeviceKind, JointKind
from ..query import Request, Response
from ..query import Request, Response, ResponseDeduper
class LutItem:
@@ -36,106 +35,6 @@ class LutItem:
return self.__circuit.compute(self.__device_kind)
class ResultBucket(Iterable[LutItem]):
"""
A bounded bucket that keeps up to `N` LutItem entries with the smallest floats.
When the bucket is full, inserting a new item only succeeds if its float
is less than the current maximum; the maximum is then evicted.
"""
class ResultBucketItem:
"""
An item stored in a :class:`ResultBucket`.
"""
__score: float
"""The score associated with this item."""
__item: LutItem
"""The underlying LutItem."""
__seq: int
"""
Monotonic counter used as a tiebreaker when scores are equal,
ensuring that heapq never compares :class:`LutItem` directly.
"""
def __init__(self, score: float, item: LutItem, seq: int):
self.__score = score
self.__item = item
self.__seq = seq
@property
def score(self) -> float:
"""The score associated with this item."""
return self.__score
@property
def item(self) -> LutItem:
"""The underlying LutItem."""
return self.__item
def __lt__(self, other: "ResultBucket.ResultBucketItem") -> bool:
# heapq is a min-heap: it always pops the smallest element.
# We invert the comparison so that an item with a larger score
# is considered "smaller", effectively turning the min-heap
# into a max-heap (largest-score item at the top).
if self.__score != other.__score:
return self.__score > other.__score
# Counter tiebreaker: when scores are equal the later-inserted
# item (higher seq) is considered "smaller" and gets evicted first.
return self.__seq > other.__seq
__n: int
"""Maximum number of items the bucket can hold."""
__heap: list[ResultBucketItem]
"""
Min-heap of :class:`ResultBucketItem`. The heap invariant is inverted
via :meth:`ResultBucketItem.__lt__` so the entry with the largest score
sits at index 0.
"""
__counter: int
"""
Monotonic counter fed to each :class:`ResultBucketItem` as a tiebreaker,
preventing heapq from comparing :class:`LutItem` on score collisions.
"""
def __init__(self, n: int):
self.__n = n
self.__heap = []
self.__counter = 0
def __len__(self) -> int:
return len(self.__heap)
def __iter__(self) -> Iterator[LutItem]:
for entry in self.__heap:
yield entry.item
def insert(self, item: LutItem, score: float) -> bool:
"""
Insert a :class:`LutItem` with the given score.
If the bucket is not yet full the item is always inserted.
Otherwise the item is only inserted when *score* is smaller
than the largest score currently in the bucket; the entry
with the largest score is then evicted.
:param item: The LutItem to insert.
:param score: The score associated with the item.
:return: ``True`` if the item was inserted, ``False`` otherwise.
"""
entry = ResultBucket.ResultBucketItem(score, item, self.__counter)
if len(self.__heap) < self.__n:
heapq.heappush(self.__heap, entry)
self.__counter += 1
return True
if score >= self.__heap[0].score:
return False
heapq.heapreplace(self.__heap, entry)
self.__counter += 1
return True
class LutResolver(Resolver):
"""
A resolver that uses a lookup table to find the best matching circuit.
@@ -149,21 +48,20 @@ class LutResolver(Resolver):
"""The lookup table for inductors."""
def __init__(self, datasets: DatasetCollection):
self.__resistor_lut = LutResolver.__build_lut(
self.__resistor_lut = self.__build_lut(
datasets.resistor_values, DeviceKind.RESISTOR
)
self.__capacitor_lut = LutResolver.__build_lut(
self.__capacitor_lut = self.__build_lut(
datasets.capacitor_values, DeviceKind.CAPACITOR
)
self.__inductor_lut = LutResolver.__build_lut(
self.__inductor_lut = self.__build_lut(
datasets.inductor_values, DeviceKind.INDUCTOR
)
@staticmethod
def __build_lut(dataset: Dataset, device_kind: DeviceKind) -> list[LutItem]:
def __build_lut(self, dataset: Dataset, device_kind: DeviceKind) -> list[LutItem]:
values = dataset.values
joints = tuple(JointKind)
return [
lut = [
LutItem(circuit, device_kind)
for circuit in chain(
(Circuit.from_one_device(v1) for v1 in values),
@@ -179,9 +77,10 @@ class LutResolver(Resolver):
),
)
]
lut.sort(key=lambda item: item.value)
return lut
def resolve(self, request: Request) -> Response:
# Fetch LUT by device kind
lut: list[LutItem]
match request.device_kind:
case DeviceKind.RESISTOR:
@@ -191,16 +90,50 @@ class LutResolver(Resolver):
case DeviceKind.INDUCTOR:
lut = self.__inductor_lut
# Check LUT item one by one
bucket = ResultBucket(min(request.count_limit, 100))
for item in lut:
# compute absolute difference
difference = abs(request.target_value - item.value)
# If it is out of tolerance, skip it directly.
if difference > request.tolerance:
continue
# put it into bucket
bucket.insert(item, difference)
target = request.target_value
count_limit = min(request.count_limit, 100)
deduper = ResponseDeduper()
# Return result
return Response(request, map(lambda item: item.circuit, bucket))
# Locate the insertion point of target in the sorted LUT.
# left/right start at the two nearest neighbours and expand outward.
idx = bisect.bisect_left(lut, target, key=lambda item: item.value)
left = idx - 1
right = idx
# Expand outward non-symmetrically: at each step compare the two
# candidates on each side and advance the one that is closer to the
# target. This guarantees items are visited in strictly increasing
# difference order, so the first N items within tolerance are exactly
# the N best matches.
while left >= 0 or right < len(lut):
if len(deduper) >= count_limit:
break
if left < 0:
go_left = False
elif right >= len(lut):
go_left = True
else:
go_left = (target - lut[left].value) <= (lut[right].value - target)
if go_left:
item = lut[left]
left -= 1
else:
item = lut[right]
right += 1
diff = abs(target - item.value)
# Since the LUT is sorted, values on each side only move further
# from target as we advance. Once one side exceeds tolerance,
# the rest of that side is guaranteed out of range — disable it.
if diff > request.tolerance:
if go_left:
left = -1
else:
right = len(lut)
continue
deduper.add(item.circuit)
return Response(request, deduper)