1
0

feat: finish lut resolver

This commit is contained in:
2026-06-15 15:58:04 +08:00
parent 7721672b8e
commit ed8f5e1943
6 changed files with 312 additions and 238 deletions

View File

@@ -76,7 +76,8 @@ class SubCircuit:
case JointKind.PARALLEL:
return (self.__device_value * value) / (self.__device_value + value)
def get_device_value(self) -> float:
@property
def device_value(self) -> float:
"""
Get the device value
@@ -84,7 +85,8 @@ class SubCircuit:
"""
return self.__device_value
def get_joint_kind(self) -> JointKind:
@property
def joint_kind(self) -> JointKind:
"""
Get the joint kind
@@ -93,8 +95,8 @@ class SubCircuit:
return self.__joint_kind
class CircuitDeviceCount(enum.IntEnum):
"""The number of devices in the circuit"""
class CircuitDeviceScale(enum.IntEnum):
"""The scale of devices in the circuit"""
ONE = enum.auto()
"""One device"""
@@ -104,12 +106,17 @@ class CircuitDeviceCount(enum.IntEnum):
"""Three devices"""
def to_device_count(self) -> int:
"""
Convert circuit device scale to device count
:return: The device count
"""
match self:
case CircuitDeviceCount.ONE:
case CircuitDeviceScale.ONE:
return 1
case CircuitDeviceCount.TWO:
case CircuitDeviceScale.TWO:
return 2
case CircuitDeviceCount.THREE:
case CircuitDeviceScale.THREE:
return 3
@@ -167,70 +174,6 @@ class Circuit:
SubCircuit(device3_value, device3_joint),
)
def get_device_count(self) -> CircuitDeviceCount:
if self.__third_device_subckt is not None:
return CircuitDeviceCount.THREE
elif self.__second_device_subckt is not None:
return CircuitDeviceCount.TWO
else:
return CircuitDeviceCount.ONE
def get_first_device_value(self) -> float:
"""
Get the value of the first device
:return: The value of the first device
"""
return self.__first_device_value
def get_second_device_joint(self) -> JointKind:
"""
Get the joint kind of the second device
:return: The joint kind of the second device
:raises LcrConnException: If there is no second device
"""
if self.__second_device_subckt is not None:
return self.__second_device_subckt.get_joint_kind()
else:
raise LcrConnException("No second device")
def get_second_device_value(self) -> float:
"""
Get the value of the second device
:return: The value of the second device
:raises LcrConnException: If there is no second device
"""
if self.__second_device_subckt is not None:
return self.__second_device_subckt.get_device_value()
else:
raise LcrConnException("No second device")
def get_third_device_joint(self) -> JointKind:
"""
Get the joint kind of the third device
:return: The joint kind of the third device
:raises LcrConnException: If there is no third device
"""
if self.__third_device_subckt is not None:
return self.__third_device_subckt.get_joint_kind()
else:
raise LcrConnException("No third device")
def get_third_device_value(self) -> float:
"""
Get the value of the third device
:return: The value of the third device
:raises LcrConnException: If there is no third device
"""
if self.__third_device_subckt is not None:
return self.__third_device_subckt.get_device_value()
else:
raise LcrConnException("No third device")
def compute(self, device_kind: DeviceKind) -> float:
"""
Compute the circuit value
@@ -249,3 +192,78 @@ class Circuit:
return value
value = self.__third_device_subckt.compute(value, device_kind)
return value
@property
def device_scale(self) -> CircuitDeviceScale:
"""
Get the device scale
:return: The device scale
"""
if self.__third_device_subckt is not None:
return CircuitDeviceScale.THREE
elif self.__second_device_subckt is not None:
return CircuitDeviceScale.TWO
else:
return CircuitDeviceScale.ONE
@property
def first_device_value(self) -> float:
"""
Get the value of the first device
:return: The value of the first device
"""
return self.__first_device_value
@property
def second_device_joint(self) -> JointKind:
"""
Get the joint kind of the second device
:return: The joint kind of the second device
:raises LcrConnException: If there is no second device
"""
if self.__second_device_subckt is not None:
return self.__second_device_subckt.joint_kind
else:
raise LcrConnException("No second device")
@property
def second_device_value(self) -> float:
"""
Get the value of the second device
:return: The value of the second device
:raises LcrConnException: If there is no second device
"""
if self.__second_device_subckt is not None:
return self.__second_device_subckt.device_value
else:
raise LcrConnException("No second device")
@property
def third_device_joint(self) -> JointKind:
"""
Get the joint kind of the third device
:return: The joint kind of the third device
:raises LcrConnException: If there is no third device
"""
if self.__third_device_subckt is not None:
return self.__third_device_subckt.joint_kind
else:
raise LcrConnException("No third device")
@property
def third_device_value(self) -> float:
"""
Get the value of the third device
:return: The value of the third device
:raises LcrConnException: If there is no third device
"""
if self.__third_device_subckt is not None:
return self.__third_device_subckt.device_value
else:
raise LcrConnException("No third device")

View File

@@ -20,7 +20,9 @@ class Dataset:
# Check redundant parts
valueset = set(values)
if len(valueset) != len(values):
raise LcrConnException(f"Duplicate standard value")
raise LcrConnException(f"Duplicate item in standard value list")
if len(valueset) == 0:
raise LcrConnException(f"Empty standard value list is not allowed")
# Ok, assign it
self.__values = values
@@ -45,7 +47,8 @@ class Dataset:
legal_lines = filter(lambda line: line != "", (line.strip() for line in f))
return Dataset.from_iterable(legal_lines)
def get_values(self) -> tuple[float, ...]:
@property
def values(self) -> tuple[float, ...]:
"""
Get the available standard values
@@ -99,29 +102,32 @@ class DatasetCollection:
Dataset.from_file(inductor),
)
def get_resistor_values(self) -> tuple[float, ...]:
@property
def resistor_values(self) -> Dataset:
"""
Get the available standard values for resistor
:return: A tuple of available standard values for resistor
"""
return self.__resistor.get_values()
return self.__resistor
def get_capacitor_values(self) -> tuple[float, ...]:
@property
def capacitor_values(self) -> Dataset:
"""
Get the available standard values for capacitor
:return: A tuple of available standard values for capacitor
"""
return self.__capacitor.get_values()
return self.__capacitor
def get_inductor_values(self) -> tuple[float, ...]:
@property
def inductor_values(self) -> Dataset:
"""
Get the available standard values for inductor
:return: A tuple of available standard values for inductor
"""
return self.__inductor.get_values()
return self.__inductor
def from_human_readable_value(strl: str) -> float:

View File

@@ -0,0 +1,11 @@
from .common import Resolver, ResultPriority, ResolverRequest
from .lut import LutResolver
from .astar import AStarResolver
__all__ = [
'Resolver',
'ResultPriority',
'ResolverRequest',
'LutResolver',
'AStarResolver'
]

View File

@@ -1,6 +1,7 @@
from typing import Iterator
from .common import Resolver, ResolverRequest, ResolverResult, ResultPriority
from .common import Resolver, ResolverRequest, ResultPriority
from ..dataset import DatasetCollection
from ..common import Circuit
class AStarResolver(Resolver):
"""
@@ -11,5 +12,5 @@ class AStarResolver(Resolver):
pass
def resolve(self, request: ResolverRequest) -> Iterator[ResolverResult]:
def resolve(self, request: ResolverRequest) -> Iterator[Circuit]:
pass

View File

@@ -19,7 +19,7 @@ class ResultPriority(enum.Enum):
@dataclass
class ResolverRequest:
"""
The request object for the resolver.
All request infomation for the resolver.
"""
device_kind: DeviceKind
@@ -27,69 +27,11 @@ class ResolverRequest:
target_value: float
"""The target value of the device."""
tolerance: float
"""The tolerance of the device."""
"""The tolerance of the device in absolute value."""
result_priority: ResultPriority
"""The priority of the result."""
count_limit: int
"""The limit of the count of results."""
class ResolverResult:
"""
The result of the resolver.
"""
circuit: Circuit
"""The circuit of the result."""
__value_cache: float | None
"""The cache of the circuit value."""
__difference_cache: float | None
"""The cache of the difference between the target value and the circuit value."""
__relative_difference_cache: float | None
"""The cache of the relative difference between the target value and the circuit value."""
def __init__(self, circuit: Circuit):
self.circuit = circuit
self.__value_cache = None
self.__difference_cache = None
self.__relative_difference_cache = None
def compute(self, device_kind: DeviceKind) -> float:
"""
Compute the circuit value.
"""
if self.__value_cache is None:
self.__value_cache = self.circuit.compute(device_kind)
return self.__value_cache
def difference(self, target_value: float, device_kind: DeviceKind) -> float:
"""
Get the difference between the target value and the circuit value.
"""
if self.__difference_cache is None:
self.__difference_cache = abs(
target_value - self.circuit.compute(device_kind)
)
return self.__difference_cache
def relative_difference(
self, target_value: float, device_kind: DeviceKind
) -> float:
"""
Get the relative difference between the target value and the circuit value.
"""
if self.__relative_difference_cache is None:
self.__relative_difference_cache = (
abs(target_value - self.circuit.compute(device_kind)) / target_value
)
return self.__relative_difference_cache
def len_devices(self) -> int:
"""
Get the number of devices in the circuit.
"""
return self.circuit.len_devices()
"""The limited count of results."""
class Resolver(ABC):
@@ -98,5 +40,5 @@ class Resolver(ABC):
"""
@abstractmethod
def resolve(self, request: ResolverRequest) -> Iterator[ResolverResult]:
def resolve(self, request: ResolverRequest) -> Iterator[Circuit]:
pass

View File

@@ -1,38 +1,10 @@
import struct
from typing import Iterator, BinaryIO
from pathlib import Path
from .common import Resolver, ResolverRequest, ResolverResult, ResultPriority
from ..dataset import DatasetCollection
from ..common import Circuit, SubCircuit, JointKind, LcrConnException
class LutResolver(Resolver):
"""
A resolver that uses a lookup table to find the best matching circuit.
"""
lut: tuple[Circuit]
def __init__(self, lut: tuple[Circuit]):
self.lut = lut
@staticmethod
def from_dataset(dataset: DatasetCollection) -> 'LutResolver':
pass
@staticmethod
def from_cache(filename: Path) -> 'LutResolver':
with open(filename, "rb") as f:
cnt = _read_int(f)
return LutResolver(tuple(LutItem.from_cache(f) for _ in range(cnt)))
def save_as_cache(self, filename: Path) -> None:
with open(filename, "wb") as f:
_write_int(f, len(self.lut))
for item in self.lut:
item.save_as_cache(f)
def resolve(self, request: ResolverRequest) -> Iterator[ResolverResult]:
pass
import heapq
from itertools import chain, product
from typing import Iterable, Iterator
from functools import cached_property
from .common import Resolver, ResolverRequest, ResultPriority
from ..dataset import DatasetCollection, Dataset
from ..common import Circuit, DeviceKind, JointKind, LcrConnException
class LutItem:
@@ -40,70 +12,194 @@ class LutItem:
An item in the lookup table.
"""
circuit: Circuit
__circuit: Circuit
"""The circuit represented by this item."""
__value_cache: float | None
"""The cached computed value of the circuit, or None if it has not been cached yet."""
__device_kind: DeviceKind
"""The device kind applied for this circuit."""
def __init__(self, circuit: Circuit):
self.circuit = circuit
def __init__(self, circuit: Circuit, device_kind: DeviceKind):
self.__circuit = circuit
self.__device_kind = device_kind
@property
def circuit(self) -> Circuit:
return self.__circuit
@cached_property
def value(self) -> float:
"""
The computed value of the circuit.
:return: The computed value.
"""
return self.__circuit.compute(self.__device_kind)
class ResultBucket(Iterable[LutItem]):
"""
A bounded bucket that keeps up to `N` LutItem entries with the smallest floats.
When the bucket is full, inserting a new item only succeeds if its float
is less than the current maximum; the maximum is then evicted.
"""
class ResultBucketItem:
"""
An item stored in a :class:`ResultBucket`.
"""
__score: float
"""The score associated with this item."""
__item: LutItem
"""The underlying LutItem."""
__seq: int
"""
Monotonic counter used as a tiebreaker when scores are equal,
ensuring that heapq never compares :class:`LutItem` directly.
"""
def __init__(self, score: float, item: LutItem, seq: int):
self.__score = score
self.__item = item
self.__seq = seq
@property
def score(self) -> float:
"""The score associated with this item."""
return self.__score
@property
def item(self) -> LutItem:
"""The underlying LutItem."""
return self.__item
def __lt__(self, other: 'ResultBucket.ResultBucketItem') -> bool:
# heapq is a min-heap: it always pops the smallest element.
# We invert the comparison so that an item with a larger score
# is considered "smaller", effectively turning the min-heap
# into a max-heap (largest-score item at the top).
if self.__score != other.__score:
return self.__score > other.__score
# Counter tiebreaker: when scores are equal the later-inserted
# item (higher seq) is considered "smaller" and gets evicted first.
return self.__seq > other.__seq
__n: int
"""Maximum number of items the bucket can hold."""
__heap: list[ResultBucketItem]
"""
Min-heap of :class:`ResultBucketItem`. The heap invariant is inverted
via :meth:`ResultBucketItem.__lt__` so the entry with the largest score
sits at index 0.
"""
__counter: int
"""
Monotonic counter fed to each :class:`ResultBucketItem` as a tiebreaker,
preventing heapq from comparing :class:`LutItem` on score collisions.
"""
def __init__(self, n: int):
self.__n = n
self.__heap = []
self.__counter = 0
def __len__(self) -> int:
return len(self.__heap)
def __iter__(self) -> Iterator[LutItem]:
for entry in self.__heap:
yield entry.item
def insert(self, item: LutItem, score: float) -> bool:
"""
Insert a :class:`LutItem` with the given score.
If the bucket is not yet full the item is always inserted.
Otherwise the item is only inserted when *score* is smaller
than the largest score currently in the bucket; the entry
with the largest score is then evicted.
:param item: The LutItem to insert.
:param score: The score associated with the item.
:return: ``True`` if the item was inserted, ``False`` otherwise.
"""
entry = ResultBucket.ResultBucketItem(score, item, self.__counter)
if len(self.__heap) < self.__n:
heapq.heappush(self.__heap, entry)
self.__counter += 1
return True
if score >= self.__heap[0].score:
return False
heapq.heapreplace(self.__heap, entry)
self.__counter += 1
return True
class LutResolver(Resolver):
"""
A resolver that uses a lookup table to find the best matching circuit.
"""
__resistor_lut: list[LutItem]
"""The lookup table for resistors."""
__capacitor_lut: list[LutItem]
"""The lookup table for capacitors."""
__inductor_lut: list[LutItem]
"""The lookup table for inductors."""
def __init__(self, datasets: DatasetCollection):
self.__resistor_lut = LutResolver.__build_lut(
datasets.resistor_values, DeviceKind.RESISTOR
)
self.__capacitor_lut = LutResolver.__build_lut(
datasets.capacitor_values, DeviceKind.CAPACITOR
)
self.__inductor_lut = LutResolver.__build_lut(
datasets.inductor_values, DeviceKind.INDUCTOR
)
@staticmethod
def from_cache(f: BinaryIO) -> 'LutItem':
cnt = _read_int(f)
if cnt < 1:
raise LcrConnException("Invalid circuit count in LUT item")
device_value = _read_double(f)
circuit = Circuit(device_value)
cnt -= 1
for _ in range(cnt):
j = JointKind.SERIES if _read_bool(f) else JointKind.PARALLEL
dev = _read_double(f)
joint = SubCircuit(j, dev)
circuit.add_joint(joint)
return LutItem(circuit)
def save_as_cache(self, f: BinaryIO) -> None:
_write_int(f, self.circuit.len_devices())
_write_double(f, self.circuit.__first_device_value)
for joint in self.circuit.joints():
_write_bool(f, joint.kind == JointKind.SERIES)
_write_double(f, joint.value)
def compute(self) -> float:
"""The computed value of the circuit."""
if self.__value_cache is None:
self.__value_cache = self.circuit.value()
return self.__value_cache
DOUBLE_PACKER = struct.Struct("d")
INT_PACKER = struct.Struct("I")
BOOL_PACKER = struct.Struct("?")
def _read_double(fs) -> float:
return DOUBLE_PACKER.unpack(fs.read(DOUBLE_PACKER.size))[0]
def _read_int(fs) -> int:
return INT_PACKER.unpack(fs.read(INT_PACKER.size))[0]
def _read_bool(fs) -> bool:
return BOOL_PACKER.unpack(fs.read(BOOL_PACKER.size))[0]
def _write_double(fs, num: float):
fs.write(DOUBLE_PACKER.pack(num))
def _write_int(fs, num: int):
fs.write(INT_PACKER.pack(num))
def _write_bool(fs, num: bool):
fs.write(BOOL_PACKER.pack(num))
def __build_lut(dataset: Dataset, device_kind: DeviceKind) -> list[LutItem]:
values = dataset.values
joints = tuple(JointKind)
return [
LutItem(circuit, device_kind)
for circuit in chain(
(Circuit.from_one_device(v1) for v1 in values),
(
Circuit.from_two_devices(v1, v2, j2)
for v1, v2, j2 in product(values, values, joints)
),
(
Circuit.from_three_devices(v1, v2, j2, v3, j3)
for v1, v2, j2, v3, j3 in product(
values, values, joints, values, joints
)
),
)
]
def resolve(self, request: ResolverRequest) -> Iterator[Circuit]:
# Fetch LUT by device kind
lut: list[LutItem]
match request.device_kind:
case DeviceKind.RESISTOR:
lut = self.__resistor_lut
case DeviceKind.CAPACITOR:
lut = self.__capacitor_lut
case DeviceKind.INDUCTOR:
lut = self.__inductor_lut
# Check LUT item one by one
bucket = ResultBucket(min(request.count_limit, 100))
for item in lut:
# compute absolute difference
difference = abs(request.target_value - item.value)
# If it is out of tolerance, skip it directly.
if difference > request.tolerance:
continue
# put it into bucket
bucket.insert(item, difference)
# Return result
return map(lambda item: item.circuit, bucket)