1
0

feat: finish lut resolver

This commit is contained in:
2026-06-15 15:58:04 +08:00
parent 7721672b8e
commit ed8f5e1943
6 changed files with 312 additions and 238 deletions

View File

@@ -76,7 +76,8 @@ class SubCircuit:
case JointKind.PARALLEL: case JointKind.PARALLEL:
return (self.__device_value * value) / (self.__device_value + value) return (self.__device_value * value) / (self.__device_value + value)
def get_device_value(self) -> float: @property
def device_value(self) -> float:
""" """
Get the device value Get the device value
@@ -84,7 +85,8 @@ class SubCircuit:
""" """
return self.__device_value return self.__device_value
def get_joint_kind(self) -> JointKind: @property
def joint_kind(self) -> JointKind:
""" """
Get the joint kind Get the joint kind
@@ -93,8 +95,8 @@ class SubCircuit:
return self.__joint_kind return self.__joint_kind
class CircuitDeviceCount(enum.IntEnum): class CircuitDeviceScale(enum.IntEnum):
"""The number of devices in the circuit""" """The scale of devices in the circuit"""
ONE = enum.auto() ONE = enum.auto()
"""One device""" """One device"""
@@ -104,12 +106,17 @@ class CircuitDeviceCount(enum.IntEnum):
"""Three devices""" """Three devices"""
def to_device_count(self) -> int: def to_device_count(self) -> int:
"""
Convert circuit device scale to device count
:return: The device count
"""
match self: match self:
case CircuitDeviceCount.ONE: case CircuitDeviceScale.ONE:
return 1 return 1
case CircuitDeviceCount.TWO: case CircuitDeviceScale.TWO:
return 2 return 2
case CircuitDeviceCount.THREE: case CircuitDeviceScale.THREE:
return 3 return 3
@@ -167,70 +174,6 @@ class Circuit:
SubCircuit(device3_value, device3_joint), SubCircuit(device3_value, device3_joint),
) )
def get_device_count(self) -> CircuitDeviceCount:
if self.__third_device_subckt is not None:
return CircuitDeviceCount.THREE
elif self.__second_device_subckt is not None:
return CircuitDeviceCount.TWO
else:
return CircuitDeviceCount.ONE
def get_first_device_value(self) -> float:
"""
Get the value of the first device
:return: The value of the first device
"""
return self.__first_device_value
def get_second_device_joint(self) -> JointKind:
"""
Get the joint kind of the second device
:return: The joint kind of the second device
:raises LcrConnException: If there is no second device
"""
if self.__second_device_subckt is not None:
return self.__second_device_subckt.get_joint_kind()
else:
raise LcrConnException("No second device")
def get_second_device_value(self) -> float:
"""
Get the value of the second device
:return: The value of the second device
:raises LcrConnException: If there is no second device
"""
if self.__second_device_subckt is not None:
return self.__second_device_subckt.get_device_value()
else:
raise LcrConnException("No second device")
def get_third_device_joint(self) -> JointKind:
"""
Get the joint kind of the third device
:return: The joint kind of the third device
:raises LcrConnException: If there is no third device
"""
if self.__third_device_subckt is not None:
return self.__third_device_subckt.get_joint_kind()
else:
raise LcrConnException("No third device")
def get_third_device_value(self) -> float:
"""
Get the value of the third device
:return: The value of the third device
:raises LcrConnException: If there is no third device
"""
if self.__third_device_subckt is not None:
return self.__third_device_subckt.get_device_value()
else:
raise LcrConnException("No third device")
def compute(self, device_kind: DeviceKind) -> float: def compute(self, device_kind: DeviceKind) -> float:
""" """
Compute the circuit value Compute the circuit value
@@ -249,3 +192,78 @@ class Circuit:
return value return value
value = self.__third_device_subckt.compute(value, device_kind) value = self.__third_device_subckt.compute(value, device_kind)
return value return value
@property
def device_scale(self) -> CircuitDeviceScale:
"""
Get the device scale
:return: The device scale
"""
if self.__third_device_subckt is not None:
return CircuitDeviceScale.THREE
elif self.__second_device_subckt is not None:
return CircuitDeviceScale.TWO
else:
return CircuitDeviceScale.ONE
@property
def first_device_value(self) -> float:
"""
Get the value of the first device
:return: The value of the first device
"""
return self.__first_device_value
@property
def second_device_joint(self) -> JointKind:
"""
Get the joint kind of the second device
:return: The joint kind of the second device
:raises LcrConnException: If there is no second device
"""
if self.__second_device_subckt is not None:
return self.__second_device_subckt.joint_kind
else:
raise LcrConnException("No second device")
@property
def second_device_value(self) -> float:
"""
Get the value of the second device
:return: The value of the second device
:raises LcrConnException: If there is no second device
"""
if self.__second_device_subckt is not None:
return self.__second_device_subckt.device_value
else:
raise LcrConnException("No second device")
@property
def third_device_joint(self) -> JointKind:
"""
Get the joint kind of the third device
:return: The joint kind of the third device
:raises LcrConnException: If there is no third device
"""
if self.__third_device_subckt is not None:
return self.__third_device_subckt.joint_kind
else:
raise LcrConnException("No third device")
@property
def third_device_value(self) -> float:
"""
Get the value of the third device
:return: The value of the third device
:raises LcrConnException: If there is no third device
"""
if self.__third_device_subckt is not None:
return self.__third_device_subckt.device_value
else:
raise LcrConnException("No third device")

View File

@@ -20,7 +20,9 @@ class Dataset:
# Check redundant parts # Check redundant parts
valueset = set(values) valueset = set(values)
if len(valueset) != len(values): if len(valueset) != len(values):
raise LcrConnException(f"Duplicate standard value") raise LcrConnException(f"Duplicate item in standard value list")
if len(valueset) == 0:
raise LcrConnException(f"Empty standard value list is not allowed")
# Ok, assign it # Ok, assign it
self.__values = values self.__values = values
@@ -45,7 +47,8 @@ class Dataset:
legal_lines = filter(lambda line: line != "", (line.strip() for line in f)) legal_lines = filter(lambda line: line != "", (line.strip() for line in f))
return Dataset.from_iterable(legal_lines) return Dataset.from_iterable(legal_lines)
def get_values(self) -> tuple[float, ...]: @property
def values(self) -> tuple[float, ...]:
""" """
Get the available standard values Get the available standard values
@@ -99,29 +102,32 @@ class DatasetCollection:
Dataset.from_file(inductor), Dataset.from_file(inductor),
) )
def get_resistor_values(self) -> tuple[float, ...]: @property
def resistor_values(self) -> Dataset:
""" """
Get the available standard values for resistor Get the available standard values for resistor
:return: A tuple of available standard values for resistor :return: A tuple of available standard values for resistor
""" """
return self.__resistor.get_values() return self.__resistor
def get_capacitor_values(self) -> tuple[float, ...]: @property
def capacitor_values(self) -> Dataset:
""" """
Get the available standard values for capacitor Get the available standard values for capacitor
:return: A tuple of available standard values for capacitor :return: A tuple of available standard values for capacitor
""" """
return self.__capacitor.get_values() return self.__capacitor
def get_inductor_values(self) -> tuple[float, ...]: @property
def inductor_values(self) -> Dataset:
""" """
Get the available standard values for inductor Get the available standard values for inductor
:return: A tuple of available standard values for inductor :return: A tuple of available standard values for inductor
""" """
return self.__inductor.get_values() return self.__inductor
def from_human_readable_value(strl: str) -> float: def from_human_readable_value(strl: str) -> float:

View File

@@ -0,0 +1,11 @@
from .common import Resolver, ResultPriority, ResolverRequest
from .lut import LutResolver
from .astar import AStarResolver
__all__ = [
'Resolver',
'ResultPriority',
'ResolverRequest',
'LutResolver',
'AStarResolver'
]

View File

@@ -1,6 +1,7 @@
from typing import Iterator from typing import Iterator
from .common import Resolver, ResolverRequest, ResolverResult, ResultPriority from .common import Resolver, ResolverRequest, ResultPriority
from ..dataset import DatasetCollection from ..dataset import DatasetCollection
from ..common import Circuit
class AStarResolver(Resolver): class AStarResolver(Resolver):
""" """
@@ -11,5 +12,5 @@ class AStarResolver(Resolver):
pass pass
def resolve(self, request: ResolverRequest) -> Iterator[ResolverResult]: def resolve(self, request: ResolverRequest) -> Iterator[Circuit]:
pass pass

View File

@@ -19,7 +19,7 @@ class ResultPriority(enum.Enum):
@dataclass @dataclass
class ResolverRequest: class ResolverRequest:
""" """
The request object for the resolver. All request infomation for the resolver.
""" """
device_kind: DeviceKind device_kind: DeviceKind
@@ -27,69 +27,11 @@ class ResolverRequest:
target_value: float target_value: float
"""The target value of the device.""" """The target value of the device."""
tolerance: float tolerance: float
"""The tolerance of the device.""" """The tolerance of the device in absolute value."""
result_priority: ResultPriority result_priority: ResultPriority
"""The priority of the result.""" """The priority of the result."""
count_limit: int count_limit: int
"""The limit of the count of results.""" """The limited count of results."""
class ResolverResult:
"""
The result of the resolver.
"""
circuit: Circuit
"""The circuit of the result."""
__value_cache: float | None
"""The cache of the circuit value."""
__difference_cache: float | None
"""The cache of the difference between the target value and the circuit value."""
__relative_difference_cache: float | None
"""The cache of the relative difference between the target value and the circuit value."""
def __init__(self, circuit: Circuit):
self.circuit = circuit
self.__value_cache = None
self.__difference_cache = None
self.__relative_difference_cache = None
def compute(self, device_kind: DeviceKind) -> float:
"""
Compute the circuit value.
"""
if self.__value_cache is None:
self.__value_cache = self.circuit.compute(device_kind)
return self.__value_cache
def difference(self, target_value: float, device_kind: DeviceKind) -> float:
"""
Get the difference between the target value and the circuit value.
"""
if self.__difference_cache is None:
self.__difference_cache = abs(
target_value - self.circuit.compute(device_kind)
)
return self.__difference_cache
def relative_difference(
self, target_value: float, device_kind: DeviceKind
) -> float:
"""
Get the relative difference between the target value and the circuit value.
"""
if self.__relative_difference_cache is None:
self.__relative_difference_cache = (
abs(target_value - self.circuit.compute(device_kind)) / target_value
)
return self.__relative_difference_cache
def len_devices(self) -> int:
"""
Get the number of devices in the circuit.
"""
return self.circuit.len_devices()
class Resolver(ABC): class Resolver(ABC):
@@ -98,5 +40,5 @@ class Resolver(ABC):
""" """
@abstractmethod @abstractmethod
def resolve(self, request: ResolverRequest) -> Iterator[ResolverResult]: def resolve(self, request: ResolverRequest) -> Iterator[Circuit]:
pass pass

View File

@@ -1,38 +1,10 @@
import struct import heapq
from typing import Iterator, BinaryIO from itertools import chain, product
from pathlib import Path from typing import Iterable, Iterator
from .common import Resolver, ResolverRequest, ResolverResult, ResultPriority from functools import cached_property
from ..dataset import DatasetCollection from .common import Resolver, ResolverRequest, ResultPriority
from ..common import Circuit, SubCircuit, JointKind, LcrConnException from ..dataset import DatasetCollection, Dataset
from ..common import Circuit, DeviceKind, JointKind, LcrConnException
class LutResolver(Resolver):
"""
A resolver that uses a lookup table to find the best matching circuit.
"""
lut: tuple[Circuit]
def __init__(self, lut: tuple[Circuit]):
self.lut = lut
@staticmethod
def from_dataset(dataset: DatasetCollection) -> 'LutResolver':
pass
@staticmethod
def from_cache(filename: Path) -> 'LutResolver':
with open(filename, "rb") as f:
cnt = _read_int(f)
return LutResolver(tuple(LutItem.from_cache(f) for _ in range(cnt)))
def save_as_cache(self, filename: Path) -> None:
with open(filename, "wb") as f:
_write_int(f, len(self.lut))
for item in self.lut:
item.save_as_cache(f)
def resolve(self, request: ResolverRequest) -> Iterator[ResolverResult]:
pass
class LutItem: class LutItem:
@@ -40,70 +12,194 @@ class LutItem:
An item in the lookup table. An item in the lookup table.
""" """
circuit: Circuit __circuit: Circuit
"""The circuit represented by this item.""" """The circuit represented by this item."""
__value_cache: float | None __device_kind: DeviceKind
"""The cached computed value of the circuit, or None if it has not been cached yet.""" """The device kind applied for this circuit."""
def __init__(self, circuit: Circuit): def __init__(self, circuit: Circuit, device_kind: DeviceKind):
self.circuit = circuit self.__circuit = circuit
self.__device_kind = device_kind
@property
def circuit(self) -> Circuit:
return self.__circuit
@cached_property
def value(self) -> float:
"""
The computed value of the circuit.
:return: The computed value.
"""
return self.__circuit.compute(self.__device_kind)
class ResultBucket(Iterable[LutItem]):
"""
A bounded bucket that keeps up to `N` LutItem entries with the smallest floats.
When the bucket is full, inserting a new item only succeeds if its float
is less than the current maximum; the maximum is then evicted.
"""
class ResultBucketItem:
"""
An item stored in a :class:`ResultBucket`.
"""
__score: float
"""The score associated with this item."""
__item: LutItem
"""The underlying LutItem."""
__seq: int
"""
Monotonic counter used as a tiebreaker when scores are equal,
ensuring that heapq never compares :class:`LutItem` directly.
"""
def __init__(self, score: float, item: LutItem, seq: int):
self.__score = score
self.__item = item
self.__seq = seq
@property
def score(self) -> float:
"""The score associated with this item."""
return self.__score
@property
def item(self) -> LutItem:
"""The underlying LutItem."""
return self.__item
def __lt__(self, other: 'ResultBucket.ResultBucketItem') -> bool:
# heapq is a min-heap: it always pops the smallest element.
# We invert the comparison so that an item with a larger score
# is considered "smaller", effectively turning the min-heap
# into a max-heap (largest-score item at the top).
if self.__score != other.__score:
return self.__score > other.__score
# Counter tiebreaker: when scores are equal the later-inserted
# item (higher seq) is considered "smaller" and gets evicted first.
return self.__seq > other.__seq
__n: int
"""Maximum number of items the bucket can hold."""
__heap: list[ResultBucketItem]
"""
Min-heap of :class:`ResultBucketItem`. The heap invariant is inverted
via :meth:`ResultBucketItem.__lt__` so the entry with the largest score
sits at index 0.
"""
__counter: int
"""
Monotonic counter fed to each :class:`ResultBucketItem` as a tiebreaker,
preventing heapq from comparing :class:`LutItem` on score collisions.
"""
def __init__(self, n: int):
self.__n = n
self.__heap = []
self.__counter = 0
def __len__(self) -> int:
return len(self.__heap)
def __iter__(self) -> Iterator[LutItem]:
for entry in self.__heap:
yield entry.item
def insert(self, item: LutItem, score: float) -> bool:
"""
Insert a :class:`LutItem` with the given score.
If the bucket is not yet full the item is always inserted.
Otherwise the item is only inserted when *score* is smaller
than the largest score currently in the bucket; the entry
with the largest score is then evicted.
:param item: The LutItem to insert.
:param score: The score associated with the item.
:return: ``True`` if the item was inserted, ``False`` otherwise.
"""
entry = ResultBucket.ResultBucketItem(score, item, self.__counter)
if len(self.__heap) < self.__n:
heapq.heappush(self.__heap, entry)
self.__counter += 1
return True
if score >= self.__heap[0].score:
return False
heapq.heapreplace(self.__heap, entry)
self.__counter += 1
return True
class LutResolver(Resolver):
"""
A resolver that uses a lookup table to find the best matching circuit.
"""
__resistor_lut: list[LutItem]
"""The lookup table for resistors."""
__capacitor_lut: list[LutItem]
"""The lookup table for capacitors."""
__inductor_lut: list[LutItem]
"""The lookup table for inductors."""
def __init__(self, datasets: DatasetCollection):
self.__resistor_lut = LutResolver.__build_lut(
datasets.resistor_values, DeviceKind.RESISTOR
)
self.__capacitor_lut = LutResolver.__build_lut(
datasets.capacitor_values, DeviceKind.CAPACITOR
)
self.__inductor_lut = LutResolver.__build_lut(
datasets.inductor_values, DeviceKind.INDUCTOR
)
@staticmethod @staticmethod
def from_cache(f: BinaryIO) -> 'LutItem': def __build_lut(dataset: Dataset, device_kind: DeviceKind) -> list[LutItem]:
cnt = _read_int(f) values = dataset.values
joints = tuple(JointKind)
if cnt < 1: return [
raise LcrConnException("Invalid circuit count in LUT item") LutItem(circuit, device_kind)
device_value = _read_double(f) for circuit in chain(
circuit = Circuit(device_value) (Circuit.from_one_device(v1) for v1 in values),
cnt -= 1 (
Circuit.from_two_devices(v1, v2, j2)
for _ in range(cnt): for v1, v2, j2 in product(values, values, joints)
j = JointKind.SERIES if _read_bool(f) else JointKind.PARALLEL ),
dev = _read_double(f) (
joint = SubCircuit(j, dev) Circuit.from_three_devices(v1, v2, j2, v3, j3)
circuit.add_joint(joint) for v1, v2, j2, v3, j3 in product(
values, values, joints, values, joints
return LutItem(circuit) )
),
def save_as_cache(self, f: BinaryIO) -> None: )
_write_int(f, self.circuit.len_devices()) ]
_write_double(f, self.circuit.__first_device_value)
for joint in self.circuit.joints():
_write_bool(f, joint.kind == JointKind.SERIES)
_write_double(f, joint.value)
def compute(self) -> float:
"""The computed value of the circuit."""
if self.__value_cache is None:
self.__value_cache = self.circuit.value()
return self.__value_cache
DOUBLE_PACKER = struct.Struct("d")
INT_PACKER = struct.Struct("I")
BOOL_PACKER = struct.Struct("?")
def _read_double(fs) -> float:
return DOUBLE_PACKER.unpack(fs.read(DOUBLE_PACKER.size))[0]
def _read_int(fs) -> int:
return INT_PACKER.unpack(fs.read(INT_PACKER.size))[0]
def _read_bool(fs) -> bool:
return BOOL_PACKER.unpack(fs.read(BOOL_PACKER.size))[0]
def _write_double(fs, num: float):
fs.write(DOUBLE_PACKER.pack(num))
def _write_int(fs, num: int):
fs.write(INT_PACKER.pack(num))
def _write_bool(fs, num: bool):
fs.write(BOOL_PACKER.pack(num))
def resolve(self, request: ResolverRequest) -> Iterator[Circuit]:
# Fetch LUT by device kind
lut: list[LutItem]
match request.device_kind:
case DeviceKind.RESISTOR:
lut = self.__resistor_lut
case DeviceKind.CAPACITOR:
lut = self.__capacitor_lut
case DeviceKind.INDUCTOR:
lut = self.__inductor_lut
# Check LUT item one by one
bucket = ResultBucket(min(request.count_limit, 100))
for item in lut:
# compute absolute difference
difference = abs(request.target_value - item.value)
# If it is out of tolerance, skip it directly.
if difference > request.tolerance:
continue
# put it into bucket
bucket.insert(item, difference)
# Return result
return map(lambda item: item.circuit, bucket)