1
0

feat: finish lut resolver

This commit is contained in:
2026-06-15 15:58:04 +08:00
parent 7721672b8e
commit ed8f5e1943
6 changed files with 312 additions and 238 deletions

View File

@@ -1,38 +1,10 @@
import struct
from typing import Iterator, BinaryIO
from pathlib import Path
from .common import Resolver, ResolverRequest, ResolverResult, ResultPriority
from ..dataset import DatasetCollection
from ..common import Circuit, SubCircuit, JointKind, LcrConnException
class LutResolver(Resolver):
"""
A resolver that uses a lookup table to find the best matching circuit.
"""
lut: tuple[Circuit]
def __init__(self, lut: tuple[Circuit]):
self.lut = lut
@staticmethod
def from_dataset(dataset: DatasetCollection) -> 'LutResolver':
pass
@staticmethod
def from_cache(filename: Path) -> 'LutResolver':
with open(filename, "rb") as f:
cnt = _read_int(f)
return LutResolver(tuple(LutItem.from_cache(f) for _ in range(cnt)))
def save_as_cache(self, filename: Path) -> None:
with open(filename, "wb") as f:
_write_int(f, len(self.lut))
for item in self.lut:
item.save_as_cache(f)
def resolve(self, request: ResolverRequest) -> Iterator[ResolverResult]:
pass
import heapq
from itertools import chain, product
from typing import Iterable, Iterator
from functools import cached_property
from .common import Resolver, ResolverRequest, ResultPriority
from ..dataset import DatasetCollection, Dataset
from ..common import Circuit, DeviceKind, JointKind, LcrConnException
class LutItem:
@@ -40,70 +12,194 @@ class LutItem:
An item in the lookup table.
"""
circuit: Circuit
__circuit: Circuit
"""The circuit represented by this item."""
__value_cache: float | None
"""The cached computed value of the circuit, or None if it has not been cached yet."""
__device_kind: DeviceKind
"""The device kind applied for this circuit."""
def __init__(self, circuit: Circuit):
self.circuit = circuit
def __init__(self, circuit: Circuit, device_kind: DeviceKind):
self.__circuit = circuit
self.__device_kind = device_kind
@property
def circuit(self) -> Circuit:
return self.__circuit
@cached_property
def value(self) -> float:
"""
The computed value of the circuit.
:return: The computed value.
"""
return self.__circuit.compute(self.__device_kind)
class ResultBucket(Iterable[LutItem]):
"""
A bounded bucket that keeps up to `N` LutItem entries with the smallest floats.
When the bucket is full, inserting a new item only succeeds if its float
is less than the current maximum; the maximum is then evicted.
"""
class ResultBucketItem:
"""
An item stored in a :class:`ResultBucket`.
"""
__score: float
"""The score associated with this item."""
__item: LutItem
"""The underlying LutItem."""
__seq: int
"""
Monotonic counter used as a tiebreaker when scores are equal,
ensuring that heapq never compares :class:`LutItem` directly.
"""
def __init__(self, score: float, item: LutItem, seq: int):
self.__score = score
self.__item = item
self.__seq = seq
@property
def score(self) -> float:
"""The score associated with this item."""
return self.__score
@property
def item(self) -> LutItem:
"""The underlying LutItem."""
return self.__item
def __lt__(self, other: 'ResultBucket.ResultBucketItem') -> bool:
# heapq is a min-heap: it always pops the smallest element.
# We invert the comparison so that an item with a larger score
# is considered "smaller", effectively turning the min-heap
# into a max-heap (largest-score item at the top).
if self.__score != other.__score:
return self.__score > other.__score
# Counter tiebreaker: when scores are equal the later-inserted
# item (higher seq) is considered "smaller" and gets evicted first.
return self.__seq > other.__seq
__n: int
"""Maximum number of items the bucket can hold."""
__heap: list[ResultBucketItem]
"""
Min-heap of :class:`ResultBucketItem`. The heap invariant is inverted
via :meth:`ResultBucketItem.__lt__` so the entry with the largest score
sits at index 0.
"""
__counter: int
"""
Monotonic counter fed to each :class:`ResultBucketItem` as a tiebreaker,
preventing heapq from comparing :class:`LutItem` on score collisions.
"""
def __init__(self, n: int):
self.__n = n
self.__heap = []
self.__counter = 0
def __len__(self) -> int:
return len(self.__heap)
def __iter__(self) -> Iterator[LutItem]:
for entry in self.__heap:
yield entry.item
def insert(self, item: LutItem, score: float) -> bool:
"""
Insert a :class:`LutItem` with the given score.
If the bucket is not yet full the item is always inserted.
Otherwise the item is only inserted when *score* is smaller
than the largest score currently in the bucket; the entry
with the largest score is then evicted.
:param item: The LutItem to insert.
:param score: The score associated with the item.
:return: ``True`` if the item was inserted, ``False`` otherwise.
"""
entry = ResultBucket.ResultBucketItem(score, item, self.__counter)
if len(self.__heap) < self.__n:
heapq.heappush(self.__heap, entry)
self.__counter += 1
return True
if score >= self.__heap[0].score:
return False
heapq.heapreplace(self.__heap, entry)
self.__counter += 1
return True
class LutResolver(Resolver):
"""
A resolver that uses a lookup table to find the best matching circuit.
"""
__resistor_lut: list[LutItem]
"""The lookup table for resistors."""
__capacitor_lut: list[LutItem]
"""The lookup table for capacitors."""
__inductor_lut: list[LutItem]
"""The lookup table for inductors."""
def __init__(self, datasets: DatasetCollection):
self.__resistor_lut = LutResolver.__build_lut(
datasets.resistor_values, DeviceKind.RESISTOR
)
self.__capacitor_lut = LutResolver.__build_lut(
datasets.capacitor_values, DeviceKind.CAPACITOR
)
self.__inductor_lut = LutResolver.__build_lut(
datasets.inductor_values, DeviceKind.INDUCTOR
)
@staticmethod
def from_cache(f: BinaryIO) -> 'LutItem':
cnt = _read_int(f)
if cnt < 1:
raise LcrConnException("Invalid circuit count in LUT item")
device_value = _read_double(f)
circuit = Circuit(device_value)
cnt -= 1
for _ in range(cnt):
j = JointKind.SERIES if _read_bool(f) else JointKind.PARALLEL
dev = _read_double(f)
joint = SubCircuit(j, dev)
circuit.add_joint(joint)
return LutItem(circuit)
def save_as_cache(self, f: BinaryIO) -> None:
_write_int(f, self.circuit.len_devices())
_write_double(f, self.circuit.__first_device_value)
for joint in self.circuit.joints():
_write_bool(f, joint.kind == JointKind.SERIES)
_write_double(f, joint.value)
def compute(self) -> float:
"""The computed value of the circuit."""
if self.__value_cache is None:
self.__value_cache = self.circuit.value()
return self.__value_cache
DOUBLE_PACKER = struct.Struct("d")
INT_PACKER = struct.Struct("I")
BOOL_PACKER = struct.Struct("?")
def _read_double(fs) -> float:
return DOUBLE_PACKER.unpack(fs.read(DOUBLE_PACKER.size))[0]
def _read_int(fs) -> int:
return INT_PACKER.unpack(fs.read(INT_PACKER.size))[0]
def _read_bool(fs) -> bool:
return BOOL_PACKER.unpack(fs.read(BOOL_PACKER.size))[0]
def _write_double(fs, num: float):
fs.write(DOUBLE_PACKER.pack(num))
def _write_int(fs, num: int):
fs.write(INT_PACKER.pack(num))
def _write_bool(fs, num: bool):
fs.write(BOOL_PACKER.pack(num))
def __build_lut(dataset: Dataset, device_kind: DeviceKind) -> list[LutItem]:
values = dataset.values
joints = tuple(JointKind)
return [
LutItem(circuit, device_kind)
for circuit in chain(
(Circuit.from_one_device(v1) for v1 in values),
(
Circuit.from_two_devices(v1, v2, j2)
for v1, v2, j2 in product(values, values, joints)
),
(
Circuit.from_three_devices(v1, v2, j2, v3, j3)
for v1, v2, j2, v3, j3 in product(
values, values, joints, values, joints
)
),
)
]
def resolve(self, request: ResolverRequest) -> Iterator[Circuit]:
# Fetch LUT by device kind
lut: list[LutItem]
match request.device_kind:
case DeviceKind.RESISTOR:
lut = self.__resistor_lut
case DeviceKind.CAPACITOR:
lut = self.__capacitor_lut
case DeviceKind.INDUCTOR:
lut = self.__inductor_lut
# Check LUT item one by one
bucket = ResultBucket(min(request.count_limit, 100))
for item in lut:
# compute absolute difference
difference = abs(request.target_value - item.value)
# If it is out of tolerance, skip it directly.
if difference > request.tolerance:
continue
# put it into bucket
bucket.insert(item, difference)
# Return result
return map(lambda item: item.circuit, bucket)