feat: update legacy
- use sorted lut and bisect to optimize lut resolver - add circuit decuper - add signed diff for response item
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import enum
|
||||
import struct
|
||||
from functools import cached_property
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator
|
||||
from .common import DeviceKind, Circuit
|
||||
from typing import Iterable, Iterator
|
||||
from .common import DeviceKind, Circuit, CircuitDeviceScale
|
||||
|
||||
|
||||
class ResponsePriority(enum.Enum):
|
||||
@@ -34,24 +35,61 @@ class Request:
|
||||
"""The limited count of results."""
|
||||
|
||||
|
||||
class ResponseItem:
|
||||
class ResponseDeduperItem:
|
||||
"""
|
||||
The possible solution given by the resolver.
|
||||
The item for response deduplicator.
|
||||
"""
|
||||
|
||||
__circuit: Circuit
|
||||
"""The circuit of the response item."""
|
||||
__device_kind: DeviceKind
|
||||
"""The kind of device of this circuit."""
|
||||
__target_value: float
|
||||
"""The target value of this circuit."""
|
||||
"""The circuit of this deduplicator item."""
|
||||
|
||||
def __init__(
|
||||
self, circuit: Circuit, device_kind: DeviceKind, target_value: float
|
||||
) -> None:
|
||||
def __init__(self, circuit: Circuit) -> None:
|
||||
self.__circuit = circuit
|
||||
self.__device_kind = device_kind
|
||||
self.__target_value = target_value
|
||||
|
||||
__ONE_PACKER = struct.Struct("=id")
|
||||
__TWO_PACKER = struct.Struct("=iidd")
|
||||
__THREE_PACKER = struct.Struct("=iiiddd")
|
||||
|
||||
@cached_property
|
||||
def __uniform_circuit_presentation(self) -> bytes:
|
||||
c = self.__circuit
|
||||
|
||||
match c.device_scale:
|
||||
case CircuitDeviceScale.ONE:
|
||||
return self.__ONE_PACKER.pack(1, c.first_device_value)
|
||||
|
||||
case CircuitDeviceScale.TWO:
|
||||
v1, v2 = sorted([c.first_device_value, c.second_device_value])
|
||||
return self.__TWO_PACKER.pack(2, int(c.second_device_joint), v1, v2)
|
||||
|
||||
case CircuitDeviceScale.THREE:
|
||||
v1, v2, v3 = (
|
||||
c.first_device_value,
|
||||
c.second_device_value,
|
||||
c.third_device_value,
|
||||
)
|
||||
j2, j3 = int(c.second_device_joint), int(c.third_device_joint)
|
||||
|
||||
if j2 == j3:
|
||||
v1, v2, v3 = sorted([v1, v2, v3])
|
||||
else:
|
||||
v1, v2 = sorted([v1, v2])
|
||||
|
||||
return self.__THREE_PACKER.pack(3, j2, j3, v1, v2, v3)
|
||||
|
||||
@cached_property
|
||||
def __uniform_circuit_hash(self) -> int:
|
||||
return hash(self.__uniform_circuit_presentation)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, ResponseDeduperItem):
|
||||
return False
|
||||
return (
|
||||
self.__uniform_circuit_presentation == other.__uniform_circuit_presentation
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.__uniform_circuit_hash
|
||||
|
||||
@property
|
||||
def circuit(self) -> Circuit:
|
||||
@@ -62,7 +100,59 @@ class ResponseItem:
|
||||
"""
|
||||
return self.__circuit
|
||||
|
||||
@cached_property
|
||||
|
||||
class ResponseDeduper:
|
||||
"""
|
||||
The deduplicator for response circuits to deduplicate equivalent circuits.
|
||||
"""
|
||||
|
||||
__circuits: set[ResponseDeduperItem]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.__circuits = set()
|
||||
|
||||
def add(self, circuit: Circuit) -> None:
|
||||
self.__circuits.add(ResponseDeduperItem(circuit))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.__circuits)
|
||||
|
||||
def __iter__(self) -> Iterator[Circuit]:
|
||||
return map(lambda x: x.circuit, self.__circuits)
|
||||
|
||||
|
||||
class ResponseItem:
|
||||
"""
|
||||
The possible solution given by the resolver.
|
||||
"""
|
||||
|
||||
__circuit: Circuit
|
||||
"""The circuit of the response item."""
|
||||
__value: float
|
||||
"""The value of the response circuit."""
|
||||
__difference: float
|
||||
"""The signed difference between the target value and the value of this circuit."""
|
||||
__relative_difference: float
|
||||
"""The signed relative difference between the target value and the value of this circuit."""
|
||||
|
||||
def __init__(
|
||||
self, circuit: Circuit, device_kind: DeviceKind, target_value: float
|
||||
) -> None:
|
||||
self.__circuit = circuit
|
||||
self.__value = self.__circuit.compute(device_kind)
|
||||
self.__difference = self.__value - target_value
|
||||
self.__relative_difference = self.__difference / target_value
|
||||
|
||||
@property
|
||||
def circuit(self) -> Circuit:
|
||||
"""
|
||||
The circuit of this response item.
|
||||
|
||||
:return: The circuit.
|
||||
"""
|
||||
return self.__circuit
|
||||
|
||||
@property
|
||||
def device_count(self) -> int:
|
||||
"""
|
||||
The device count of this circuit.
|
||||
@@ -71,32 +161,56 @@ class ResponseItem:
|
||||
"""
|
||||
return self.__circuit.device_scale.to_device_count()
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def value(self) -> float:
|
||||
"""
|
||||
The value of this circuit.
|
||||
|
||||
:return: The value.
|
||||
"""
|
||||
return self.__circuit.compute(self.__device_kind)
|
||||
return self.__value
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def difference(self) -> float:
|
||||
"""
|
||||
The absolute difference between the target value and the value of this circuit.
|
||||
The signed difference between the target value and the value of this circuit.
|
||||
|
||||
:return: The absolute difference.
|
||||
Positive value indicates that the value of this circuit is greater than the target value.
|
||||
Negative value indicates that the value of this circuit is less than the target value.
|
||||
|
||||
:return: The signed difference.
|
||||
"""
|
||||
return abs(self.__target_value - self.value)
|
||||
return self.__difference
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def unsigned_difference(self) -> float:
|
||||
"""
|
||||
The unsigned difference between the target value and the value of this circuit.
|
||||
|
||||
:return: The unsigned difference.
|
||||
"""
|
||||
return abs(self.__difference)
|
||||
|
||||
@property
|
||||
def relative_difference(self) -> float:
|
||||
"""
|
||||
The relative difference between the target value and the value of this circuit.
|
||||
The signed relative difference between the target value and the value of this circuit.
|
||||
|
||||
:return: The relative difference.
|
||||
Positive value indicates that the value of this circuit is greater than the target value.
|
||||
Negative value indicates that the value of this circuit is less than the target value.
|
||||
|
||||
:return: The signed relative difference.
|
||||
"""
|
||||
return self.difference / self.__target_value
|
||||
return self.__relative_difference
|
||||
|
||||
@property
|
||||
def unsigned_relative_difference(self) -> float:
|
||||
"""
|
||||
The unsigned relative difference between the target value and the value of this circuit.
|
||||
|
||||
:return: The unsigned relative difference.
|
||||
"""
|
||||
return abs(self.__relative_difference)
|
||||
|
||||
|
||||
class Response:
|
||||
@@ -111,7 +225,7 @@ class Response:
|
||||
__sorted_items: list[ResponseItem]
|
||||
"""The sorted items by priority and difference."""
|
||||
|
||||
def __init__(self, request: Request, candidates: Iterator[Circuit]) -> None:
|
||||
def __init__(self, request: Request, candidates: Iterable[Circuit]) -> None:
|
||||
self.__sorted_items = list(
|
||||
ResponseItem(item, request.device_kind, request.target_value)
|
||||
for item in candidates
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import heapq
|
||||
import bisect
|
||||
from itertools import chain, product
|
||||
from typing import Iterable, Iterator
|
||||
from functools import cached_property
|
||||
from .common import Resolver
|
||||
from ..dataset import DatasetCollection, Dataset
|
||||
from ..common import Circuit, DeviceKind, JointKind
|
||||
from ..query import Request, Response
|
||||
from ..query import Request, Response, ResponseDeduper
|
||||
|
||||
|
||||
class LutItem:
|
||||
@@ -36,106 +35,6 @@ class LutItem:
|
||||
return self.__circuit.compute(self.__device_kind)
|
||||
|
||||
|
||||
class ResultBucket(Iterable[LutItem]):
|
||||
"""
|
||||
A bounded bucket that keeps up to `N` LutItem entries with the smallest floats.
|
||||
|
||||
When the bucket is full, inserting a new item only succeeds if its float
|
||||
is less than the current maximum; the maximum is then evicted.
|
||||
"""
|
||||
|
||||
class ResultBucketItem:
|
||||
"""
|
||||
An item stored in a :class:`ResultBucket`.
|
||||
"""
|
||||
|
||||
__score: float
|
||||
"""The score associated with this item."""
|
||||
__item: LutItem
|
||||
"""The underlying LutItem."""
|
||||
__seq: int
|
||||
"""
|
||||
Monotonic counter used as a tiebreaker when scores are equal,
|
||||
ensuring that heapq never compares :class:`LutItem` directly.
|
||||
"""
|
||||
|
||||
def __init__(self, score: float, item: LutItem, seq: int):
|
||||
self.__score = score
|
||||
self.__item = item
|
||||
self.__seq = seq
|
||||
|
||||
@property
|
||||
def score(self) -> float:
|
||||
"""The score associated with this item."""
|
||||
return self.__score
|
||||
|
||||
@property
|
||||
def item(self) -> LutItem:
|
||||
"""The underlying LutItem."""
|
||||
return self.__item
|
||||
|
||||
def __lt__(self, other: "ResultBucket.ResultBucketItem") -> bool:
|
||||
# heapq is a min-heap: it always pops the smallest element.
|
||||
# We invert the comparison so that an item with a larger score
|
||||
# is considered "smaller", effectively turning the min-heap
|
||||
# into a max-heap (largest-score item at the top).
|
||||
if self.__score != other.__score:
|
||||
return self.__score > other.__score
|
||||
# Counter tiebreaker: when scores are equal the later-inserted
|
||||
# item (higher seq) is considered "smaller" and gets evicted first.
|
||||
return self.__seq > other.__seq
|
||||
|
||||
__n: int
|
||||
"""Maximum number of items the bucket can hold."""
|
||||
__heap: list[ResultBucketItem]
|
||||
"""
|
||||
Min-heap of :class:`ResultBucketItem`. The heap invariant is inverted
|
||||
via :meth:`ResultBucketItem.__lt__` so the entry with the largest score
|
||||
sits at index 0.
|
||||
"""
|
||||
__counter: int
|
||||
"""
|
||||
Monotonic counter fed to each :class:`ResultBucketItem` as a tiebreaker,
|
||||
preventing heapq from comparing :class:`LutItem` on score collisions.
|
||||
"""
|
||||
|
||||
def __init__(self, n: int):
|
||||
self.__n = n
|
||||
self.__heap = []
|
||||
self.__counter = 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.__heap)
|
||||
|
||||
def __iter__(self) -> Iterator[LutItem]:
|
||||
for entry in self.__heap:
|
||||
yield entry.item
|
||||
|
||||
def insert(self, item: LutItem, score: float) -> bool:
|
||||
"""
|
||||
Insert a :class:`LutItem` with the given score.
|
||||
|
||||
If the bucket is not yet full the item is always inserted.
|
||||
Otherwise the item is only inserted when *score* is smaller
|
||||
than the largest score currently in the bucket; the entry
|
||||
with the largest score is then evicted.
|
||||
|
||||
:param item: The LutItem to insert.
|
||||
:param score: The score associated with the item.
|
||||
:return: ``True`` if the item was inserted, ``False`` otherwise.
|
||||
"""
|
||||
entry = ResultBucket.ResultBucketItem(score, item, self.__counter)
|
||||
if len(self.__heap) < self.__n:
|
||||
heapq.heappush(self.__heap, entry)
|
||||
self.__counter += 1
|
||||
return True
|
||||
if score >= self.__heap[0].score:
|
||||
return False
|
||||
heapq.heapreplace(self.__heap, entry)
|
||||
self.__counter += 1
|
||||
return True
|
||||
|
||||
|
||||
class LutResolver(Resolver):
|
||||
"""
|
||||
A resolver that uses a lookup table to find the best matching circuit.
|
||||
@@ -149,21 +48,20 @@ class LutResolver(Resolver):
|
||||
"""The lookup table for inductors."""
|
||||
|
||||
def __init__(self, datasets: DatasetCollection):
|
||||
self.__resistor_lut = LutResolver.__build_lut(
|
||||
self.__resistor_lut = self.__build_lut(
|
||||
datasets.resistor_values, DeviceKind.RESISTOR
|
||||
)
|
||||
self.__capacitor_lut = LutResolver.__build_lut(
|
||||
self.__capacitor_lut = self.__build_lut(
|
||||
datasets.capacitor_values, DeviceKind.CAPACITOR
|
||||
)
|
||||
self.__inductor_lut = LutResolver.__build_lut(
|
||||
self.__inductor_lut = self.__build_lut(
|
||||
datasets.inductor_values, DeviceKind.INDUCTOR
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __build_lut(dataset: Dataset, device_kind: DeviceKind) -> list[LutItem]:
|
||||
def __build_lut(self, dataset: Dataset, device_kind: DeviceKind) -> list[LutItem]:
|
||||
values = dataset.values
|
||||
joints = tuple(JointKind)
|
||||
return [
|
||||
lut = [
|
||||
LutItem(circuit, device_kind)
|
||||
for circuit in chain(
|
||||
(Circuit.from_one_device(v1) for v1 in values),
|
||||
@@ -179,9 +77,10 @@ class LutResolver(Resolver):
|
||||
),
|
||||
)
|
||||
]
|
||||
lut.sort(key=lambda item: item.value)
|
||||
return lut
|
||||
|
||||
def resolve(self, request: Request) -> Response:
|
||||
# Fetch LUT by device kind
|
||||
lut: list[LutItem]
|
||||
match request.device_kind:
|
||||
case DeviceKind.RESISTOR:
|
||||
@@ -191,16 +90,50 @@ class LutResolver(Resolver):
|
||||
case DeviceKind.INDUCTOR:
|
||||
lut = self.__inductor_lut
|
||||
|
||||
# Check LUT item one by one
|
||||
bucket = ResultBucket(min(request.count_limit, 100))
|
||||
for item in lut:
|
||||
# compute absolute difference
|
||||
difference = abs(request.target_value - item.value)
|
||||
# If it is out of tolerance, skip it directly.
|
||||
if difference > request.tolerance:
|
||||
continue
|
||||
# put it into bucket
|
||||
bucket.insert(item, difference)
|
||||
target = request.target_value
|
||||
count_limit = min(request.count_limit, 100)
|
||||
deduper = ResponseDeduper()
|
||||
|
||||
# Return result
|
||||
return Response(request, map(lambda item: item.circuit, bucket))
|
||||
# Locate the insertion point of target in the sorted LUT.
|
||||
# left/right start at the two nearest neighbours and expand outward.
|
||||
idx = bisect.bisect_left(lut, target, key=lambda item: item.value)
|
||||
left = idx - 1
|
||||
right = idx
|
||||
|
||||
# Expand outward non-symmetrically: at each step compare the two
|
||||
# candidates on each side and advance the one that is closer to the
|
||||
# target. This guarantees items are visited in strictly increasing
|
||||
# difference order, so the first N items within tolerance are exactly
|
||||
# the N best matches.
|
||||
while left >= 0 or right < len(lut):
|
||||
if len(deduper) >= count_limit:
|
||||
break
|
||||
|
||||
if left < 0:
|
||||
go_left = False
|
||||
elif right >= len(lut):
|
||||
go_left = True
|
||||
else:
|
||||
go_left = (target - lut[left].value) <= (lut[right].value - target)
|
||||
|
||||
if go_left:
|
||||
item = lut[left]
|
||||
left -= 1
|
||||
else:
|
||||
item = lut[right]
|
||||
right += 1
|
||||
|
||||
diff = abs(target - item.value)
|
||||
# Since the LUT is sorted, values on each side only move further
|
||||
# from target as we advance. Once one side exceeds tolerance,
|
||||
# the rest of that side is guaranteed out of range — disable it.
|
||||
if diff > request.tolerance:
|
||||
if go_left:
|
||||
left = -1
|
||||
else:
|
||||
right = len(lut)
|
||||
continue
|
||||
|
||||
deduper.add(item.circuit)
|
||||
|
||||
return Response(request, deduper)
|
||||
|
||||
Reference in New Issue
Block a user