feat: finish lut resolver
This commit is contained in:
@@ -1,38 +1,10 @@
|
||||
import struct
|
||||
from typing import Iterator, BinaryIO
|
||||
from pathlib import Path
|
||||
from .common import Resolver, ResolverRequest, ResolverResult, ResultPriority
|
||||
from ..dataset import DatasetCollection
|
||||
from ..common import Circuit, SubCircuit, JointKind, LcrConnException
|
||||
|
||||
class LutResolver(Resolver):
|
||||
"""
|
||||
A resolver that uses a lookup table to find the best matching circuit.
|
||||
"""
|
||||
|
||||
lut: tuple[Circuit]
|
||||
|
||||
def __init__(self, lut: tuple[Circuit]):
|
||||
self.lut = lut
|
||||
|
||||
@staticmethod
|
||||
def from_dataset(dataset: DatasetCollection) -> 'LutResolver':
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def from_cache(filename: Path) -> 'LutResolver':
|
||||
with open(filename, "rb") as f:
|
||||
cnt = _read_int(f)
|
||||
return LutResolver(tuple(LutItem.from_cache(f) for _ in range(cnt)))
|
||||
|
||||
def save_as_cache(self, filename: Path) -> None:
|
||||
with open(filename, "wb") as f:
|
||||
_write_int(f, len(self.lut))
|
||||
for item in self.lut:
|
||||
item.save_as_cache(f)
|
||||
|
||||
def resolve(self, request: ResolverRequest) -> Iterator[ResolverResult]:
|
||||
pass
|
||||
import heapq
|
||||
from itertools import chain, product
|
||||
from typing import Iterable, Iterator
|
||||
from functools import cached_property
|
||||
from .common import Resolver, ResolverRequest, ResultPriority
|
||||
from ..dataset import DatasetCollection, Dataset
|
||||
from ..common import Circuit, DeviceKind, JointKind, LcrConnException
|
||||
|
||||
|
||||
class LutItem:
|
||||
@@ -40,70 +12,194 @@ class LutItem:
|
||||
An item in the lookup table.
|
||||
"""
|
||||
|
||||
circuit: Circuit
|
||||
__circuit: Circuit
|
||||
"""The circuit represented by this item."""
|
||||
__value_cache: float | None
|
||||
"""The cached computed value of the circuit, or None if it has not been cached yet."""
|
||||
__device_kind: DeviceKind
|
||||
"""The device kind applied for this circuit."""
|
||||
|
||||
def __init__(self, circuit: Circuit):
|
||||
self.circuit = circuit
|
||||
def __init__(self, circuit: Circuit, device_kind: DeviceKind):
|
||||
self.__circuit = circuit
|
||||
self.__device_kind = device_kind
|
||||
|
||||
@property
|
||||
def circuit(self) -> Circuit:
|
||||
return self.__circuit
|
||||
|
||||
@cached_property
|
||||
def value(self) -> float:
|
||||
"""
|
||||
The computed value of the circuit.
|
||||
|
||||
:return: The computed value.
|
||||
"""
|
||||
return self.__circuit.compute(self.__device_kind)
|
||||
|
||||
|
||||
class ResultBucket(Iterable[LutItem]):
|
||||
"""
|
||||
A bounded bucket that keeps up to `N` LutItem entries with the smallest floats.
|
||||
|
||||
When the bucket is full, inserting a new item only succeeds if its float
|
||||
is less than the current maximum; the maximum is then evicted.
|
||||
"""
|
||||
|
||||
class ResultBucketItem:
|
||||
"""
|
||||
An item stored in a :class:`ResultBucket`.
|
||||
"""
|
||||
|
||||
__score: float
|
||||
"""The score associated with this item."""
|
||||
__item: LutItem
|
||||
"""The underlying LutItem."""
|
||||
__seq: int
|
||||
"""
|
||||
Monotonic counter used as a tiebreaker when scores are equal,
|
||||
ensuring that heapq never compares :class:`LutItem` directly.
|
||||
"""
|
||||
|
||||
def __init__(self, score: float, item: LutItem, seq: int):
|
||||
self.__score = score
|
||||
self.__item = item
|
||||
self.__seq = seq
|
||||
|
||||
@property
|
||||
def score(self) -> float:
|
||||
"""The score associated with this item."""
|
||||
return self.__score
|
||||
|
||||
@property
|
||||
def item(self) -> LutItem:
|
||||
"""The underlying LutItem."""
|
||||
return self.__item
|
||||
|
||||
def __lt__(self, other: 'ResultBucket.ResultBucketItem') -> bool:
|
||||
# heapq is a min-heap: it always pops the smallest element.
|
||||
# We invert the comparison so that an item with a larger score
|
||||
# is considered "smaller", effectively turning the min-heap
|
||||
# into a max-heap (largest-score item at the top).
|
||||
if self.__score != other.__score:
|
||||
return self.__score > other.__score
|
||||
# Counter tiebreaker: when scores are equal the later-inserted
|
||||
# item (higher seq) is considered "smaller" and gets evicted first.
|
||||
return self.__seq > other.__seq
|
||||
|
||||
__n: int
|
||||
"""Maximum number of items the bucket can hold."""
|
||||
__heap: list[ResultBucketItem]
|
||||
"""
|
||||
Min-heap of :class:`ResultBucketItem`. The heap invariant is inverted
|
||||
via :meth:`ResultBucketItem.__lt__` so the entry with the largest score
|
||||
sits at index 0.
|
||||
"""
|
||||
__counter: int
|
||||
"""
|
||||
Monotonic counter fed to each :class:`ResultBucketItem` as a tiebreaker,
|
||||
preventing heapq from comparing :class:`LutItem` on score collisions.
|
||||
"""
|
||||
|
||||
def __init__(self, n: int):
|
||||
self.__n = n
|
||||
self.__heap = []
|
||||
self.__counter = 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.__heap)
|
||||
|
||||
def __iter__(self) -> Iterator[LutItem]:
|
||||
for entry in self.__heap:
|
||||
yield entry.item
|
||||
|
||||
def insert(self, item: LutItem, score: float) -> bool:
|
||||
"""
|
||||
Insert a :class:`LutItem` with the given score.
|
||||
|
||||
If the bucket is not yet full the item is always inserted.
|
||||
Otherwise the item is only inserted when *score* is smaller
|
||||
than the largest score currently in the bucket; the entry
|
||||
with the largest score is then evicted.
|
||||
|
||||
:param item: The LutItem to insert.
|
||||
:param score: The score associated with the item.
|
||||
:return: ``True`` if the item was inserted, ``False`` otherwise.
|
||||
"""
|
||||
entry = ResultBucket.ResultBucketItem(score, item, self.__counter)
|
||||
if len(self.__heap) < self.__n:
|
||||
heapq.heappush(self.__heap, entry)
|
||||
self.__counter += 1
|
||||
return True
|
||||
if score >= self.__heap[0].score:
|
||||
return False
|
||||
heapq.heapreplace(self.__heap, entry)
|
||||
self.__counter += 1
|
||||
return True
|
||||
|
||||
|
||||
class LutResolver(Resolver):
|
||||
"""
|
||||
A resolver that uses a lookup table to find the best matching circuit.
|
||||
"""
|
||||
|
||||
__resistor_lut: list[LutItem]
|
||||
"""The lookup table for resistors."""
|
||||
__capacitor_lut: list[LutItem]
|
||||
"""The lookup table for capacitors."""
|
||||
__inductor_lut: list[LutItem]
|
||||
"""The lookup table for inductors."""
|
||||
|
||||
def __init__(self, datasets: DatasetCollection):
|
||||
self.__resistor_lut = LutResolver.__build_lut(
|
||||
datasets.resistor_values, DeviceKind.RESISTOR
|
||||
)
|
||||
self.__capacitor_lut = LutResolver.__build_lut(
|
||||
datasets.capacitor_values, DeviceKind.CAPACITOR
|
||||
)
|
||||
self.__inductor_lut = LutResolver.__build_lut(
|
||||
datasets.inductor_values, DeviceKind.INDUCTOR
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_cache(f: BinaryIO) -> 'LutItem':
|
||||
cnt = _read_int(f)
|
||||
|
||||
if cnt < 1:
|
||||
raise LcrConnException("Invalid circuit count in LUT item")
|
||||
device_value = _read_double(f)
|
||||
circuit = Circuit(device_value)
|
||||
cnt -= 1
|
||||
|
||||
for _ in range(cnt):
|
||||
j = JointKind.SERIES if _read_bool(f) else JointKind.PARALLEL
|
||||
dev = _read_double(f)
|
||||
joint = SubCircuit(j, dev)
|
||||
circuit.add_joint(joint)
|
||||
|
||||
return LutItem(circuit)
|
||||
|
||||
def save_as_cache(self, f: BinaryIO) -> None:
|
||||
_write_int(f, self.circuit.len_devices())
|
||||
_write_double(f, self.circuit.__first_device_value)
|
||||
for joint in self.circuit.joints():
|
||||
_write_bool(f, joint.kind == JointKind.SERIES)
|
||||
_write_double(f, joint.value)
|
||||
|
||||
def compute(self) -> float:
|
||||
"""The computed value of the circuit."""
|
||||
if self.__value_cache is None:
|
||||
self.__value_cache = self.circuit.value()
|
||||
return self.__value_cache
|
||||
|
||||
|
||||
|
||||
DOUBLE_PACKER = struct.Struct("d")
|
||||
INT_PACKER = struct.Struct("I")
|
||||
BOOL_PACKER = struct.Struct("?")
|
||||
|
||||
def _read_double(fs) -> float:
|
||||
return DOUBLE_PACKER.unpack(fs.read(DOUBLE_PACKER.size))[0]
|
||||
|
||||
def _read_int(fs) -> int:
|
||||
return INT_PACKER.unpack(fs.read(INT_PACKER.size))[0]
|
||||
|
||||
def _read_bool(fs) -> bool:
|
||||
return BOOL_PACKER.unpack(fs.read(BOOL_PACKER.size))[0]
|
||||
|
||||
def _write_double(fs, num: float):
|
||||
fs.write(DOUBLE_PACKER.pack(num))
|
||||
|
||||
def _write_int(fs, num: int):
|
||||
fs.write(INT_PACKER.pack(num))
|
||||
|
||||
def _write_bool(fs, num: bool):
|
||||
fs.write(BOOL_PACKER.pack(num))
|
||||
|
||||
|
||||
def __build_lut(dataset: Dataset, device_kind: DeviceKind) -> list[LutItem]:
|
||||
values = dataset.values
|
||||
joints = tuple(JointKind)
|
||||
return [
|
||||
LutItem(circuit, device_kind)
|
||||
for circuit in chain(
|
||||
(Circuit.from_one_device(v1) for v1 in values),
|
||||
(
|
||||
Circuit.from_two_devices(v1, v2, j2)
|
||||
for v1, v2, j2 in product(values, values, joints)
|
||||
),
|
||||
(
|
||||
Circuit.from_three_devices(v1, v2, j2, v3, j3)
|
||||
for v1, v2, j2, v3, j3 in product(
|
||||
values, values, joints, values, joints
|
||||
)
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
def resolve(self, request: ResolverRequest) -> Iterator[Circuit]:
|
||||
# Fetch LUT by device kind
|
||||
lut: list[LutItem]
|
||||
match request.device_kind:
|
||||
case DeviceKind.RESISTOR:
|
||||
lut = self.__resistor_lut
|
||||
case DeviceKind.CAPACITOR:
|
||||
lut = self.__capacitor_lut
|
||||
case DeviceKind.INDUCTOR:
|
||||
lut = self.__inductor_lut
|
||||
|
||||
# Check LUT item one by one
|
||||
bucket = ResultBucket(min(request.count_limit, 100))
|
||||
for item in lut:
|
||||
# compute absolute difference
|
||||
difference = abs(request.target_value - item.value)
|
||||
# If it is out of tolerance, skip it directly.
|
||||
if difference > request.tolerance:
|
||||
continue
|
||||
# put it into bucket
|
||||
bucket.insert(item, difference)
|
||||
|
||||
# Return result
|
||||
return map(lambda item: item.circuit, bucket)
|
||||
|
||||
Reference in New Issue
Block a user