131 lines
2.9 KiB
Python
131 lines
2.9 KiB
Python
|
|
import tomllib
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from enum import StrEnum
|
||
|
|
from typing import Optional
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
|
||
|
|
class DatabaseDriver(StrEnum):
|
||
|
|
SQLITE = "sqlite"
|
||
|
|
MYSQL = "mysql"
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def from_raw(raw: dict):
|
||
|
|
return DatabaseDriver(raw["driver"])
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class SqliteDatabaseConfig:
    """Settings for the SQLite backend."""

    path: str
    """Path to the SQLite database"""

    @staticmethod
    def from_raw(raw: dict) -> "SqliteDatabaseConfig":
        """
        Build from the backend ``config`` mapping.

        Raises KeyError if the ``path`` key is missing.
        """
        return SqliteDatabaseConfig(path=raw["path"])
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class MysqlDatabaseConfig:
    """Connection settings for the MySQL backend."""

    host: str
    """Database host"""

    port: int
    """Database port"""

    user: str
    """Database user"""

    # repr=False keeps the secret out of repr()/str() output, and so out of
    # logs and tracebacks that print this object or a config containing it.
    password: str = field(repr=False)
    """Database password"""

    database: str
    """Database name"""

    @staticmethod
    def from_raw(raw: dict) -> "MysqlDatabaseConfig":
        """
        Build from the backend ``config`` mapping.

        Values are bound to fields by name rather than by position. Otherwise
        reordering the fields could silently swap same-typed values such as
        ``user`` and ``password``. Values are passed through without type
        conversion.

        Raises KeyError if any of ``host``, ``port``, ``user``, ``password``
        or ``database`` is missing.
        """
        return MysqlDatabaseConfig(
            host=raw["host"],
            port=raw["port"],
            user=raw["user"],
            password=raw["password"],
            database=raw["database"],
        )
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
|
||
|
|
class DatabaseConfig:
|
||
|
|
driver: DatabaseDriver
|
||
|
|
"""Database driver"""
|
||
|
|
config: SqliteDatabaseConfig | MysqlDatabaseConfig
|
||
|
|
"""Database config"""
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def from_raw(raw: dict):
|
||
|
|
if raw["driver"] == DatabaseDriver.SQLITE:
|
||
|
|
return DatabaseConfig(
|
||
|
|
DatabaseDriver.SQLITE, SqliteDatabaseConfig.from_raw(raw["config"])
|
||
|
|
)
|
||
|
|
elif raw["driver"] == DatabaseDriver.MYSQL:
|
||
|
|
return DatabaseConfig(
|
||
|
|
DatabaseDriver.MYSQL, MysqlDatabaseConfig.from_raw(raw["config"])
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
raise ValueError("Invalid database driver")
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class WebConfig:
    """Settings for the web server."""

    port: int
    """Port number for the web server"""

    @staticmethod
    def from_raw(raw: dict) -> "WebConfig":
        """
        Build from the ``web`` mapping.

        Raises KeyError if the ``port`` key is missing.
        """
        return WebConfig(port=raw["port"])
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class OthersConfig:
    """Miscellaneous application settings."""

    debug: bool
    """Whether debug mode is enabled"""

    auto_token_clean_duration: int
    """Duration setting for automatic token cleanup (unit not specified here -- TODO confirm)"""

    @staticmethod
    def from_raw(raw: dict) -> "OthersConfig":
        """
        Build from the ``others`` mapping.

        The cleanup setting is read from the hyphenated key
        ``auto-token-clean-duration``. Values are passed through without
        conversion. Raises KeyError if either key is missing.
        """
        debug_enabled = raw["debug"]
        clean_duration = raw["auto-token-clean-duration"]
        return OthersConfig(debug=debug_enabled, auto_token_clean_duration=clean_duration)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class Config:
    """Top-level application configuration, one field per top-level section."""

    database: DatabaseConfig
    """Settings from the ``database`` section"""

    web: WebConfig
    """Settings from the ``web`` section"""

    others: OthersConfig
    """Settings from the ``others`` section"""

    @staticmethod
    def from_raw(raw: dict) -> "Config":
        """
        Build from the parsed configuration mapping.

        Sections are parsed in the order database, web, others. The first
        missing or invalid section raises (KeyError or ValueError).
        """
        database_section = DatabaseConfig.from_raw(raw["database"])
        web_section = WebConfig.from_raw(raw["web"])
        others_section = OthersConfig.from_raw(raw["others"])
        return Config(
            database=database_section,
            web=web_section,
            others=others_section,
        )
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level configuration: None until setup_config() succeeds, read via get_config().
_CONFIG: Config | None = None
|
||
|
|
|
||
|
|
|
||
|
|
def setup_config(p: Path) -> None:
    """
    Load the TOML file at ``p`` and install it as the module-level config.

    The stored config is replaced only after the file has been parsed and
    converted successfully. If anything raises, the previous value is kept.

    Raises:
        OSError: if the file cannot be opened.
        tomllib.TOMLDecodeError: if the file is not valid TOML.
        KeyError / ValueError: if required settings are missing or invalid.
    """
    global _CONFIG
    with open(p, "rb") as handle:
        document = tomllib.load(handle)
    _CONFIG = Config.from_raw(document)
|
||
|
|
|
||
|
|
|
||
|
|
def get_config() -> Config:
    """
    Return the configuration installed by setup_config().

    Raises RuntimeError if no configuration has been loaded yet.
    """
    current = _CONFIG
    if current is not None:
        return current
    raise RuntimeError("Config is not loaded. Call setup_config() first.")
|