2024-06-27 20:08:41 +02:00

78 lines
2.8 KiB
Python

import json
from pathlib import Path
import logging
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)

# File name of the default iottb database; a freshly created config also uses
# it as the key of that database in the ``DatabaseLocations`` mapping.
DB_NAME = "iottb.db"
class IottbConfig:
    """Load, create and persist the iottb JSON configuration file.

    The file stores three keys:

    * ``DefaultDatabase`` -- name of the default database,
    * ``DefaultDatabasePath`` -- path of the default database,
    * ``DatabaseLocations`` -- mapping of database name to its path.

    The file is read (or created with defaults when missing) on construction.
    Changes made through the setter methods only live in memory until
    :meth:`save_config` is called.
    """

    def __init__(self, cfg_file):
        """Load the configuration from *cfg_file*, creating it if missing.

        Args:
            cfg_file: Location of the JSON config file. A ``str`` or any
                ``os.PathLike`` is accepted; it is stored as a ``Path``.
        """
        # Normalise up front: the methods below rely on Path-only API
        # (``is_file``, ``parent``, ``open``) and would fail on a plain str.
        self.cfg_file = Path(cfg_file)
        self.default_database = None
        self.default_path = None
        self.database_locations = {}
        self.load_config()

    def _as_dict(self):
        """Return the current settings in the on-disk JSON layout."""
        return {
            'DefaultDatabase': self.default_database,
            'DefaultDatabasePath': self.default_path,
            'DatabaseLocations': self.database_locations,
        }

    def _write(self):
        """Serialise the current settings and write them to ``cfg_file``.

        Raises:
            TypeError: if a stored value cannot be encoded as JSON; the
                existing file is left untouched in that case.
            OSError: if the parent directory or the file cannot be written.
        """
        # Encode before opening the file: opening with 'w' truncates it, so
        # a serialisation error mid-stream would otherwise leave a corrupt,
        # half-written config behind.
        text = json.dumps(self._as_dict(), indent=4)
        self.cfg_file.parent.mkdir(parents=True, exist_ok=True)
        self.cfg_file.write_text(text, encoding='utf-8')

    def create_default_config(self):
        """Create the default iottb config file and adopt its values.

        The default database is ``DB_NAME``, placed in the user's home
        directory, and it is the only entry in ``database_locations``.

        Raises:
            RuntimeError: if the config file cannot be written.
        """
        logger.info('Creating default config file at %s', self.cfg_file)
        self.default_database = DB_NAME
        self.default_path = str(Path.home() / DB_NAME)
        self.database_locations = {DB_NAME: self.default_path}
        try:
            self._write()
        except OSError as e:
            logger.error(
                "Failed to create default configuration file at %s: %s",
                self.cfg_file, e,
            )
            raise RuntimeError(f"Failed to create configuration file: {e}") from e

    def load_config(self):
        """Load settings from ``cfg_file``, creating a default file if absent.

        Missing keys fall back to ``None``; a missing or ``null``
        ``DatabaseLocations`` falls back to an empty mapping.

        Raises:
            OSError: if the existing file cannot be read.
            ValueError: if the existing file is not valid UTF-8 JSON
                (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
            RuntimeError: if a default file had to be created but could
                not be written.
        """
        if not self.cfg_file.is_file():
            self.create_default_config()
            return
        try:
            with self.cfg_file.open('r', encoding='utf-8') as config_file:
                data = json.load(config_file)
        except (OSError, ValueError) as e:
            logger.error(
                "Failed to load configuration file at %s: %s", self.cfg_file, e
            )
            raise
        self.default_database = data.get('DefaultDatabase')
        self.default_path = data.get('DefaultDatabasePath')
        # ``or {}`` also covers an explicit JSON null, which would otherwise
        # make the set_* methods below fail on ``None``.
        self.database_locations = data.get('DatabaseLocations') or {}

    def save_config(self):
        """Write the current settings to ``cfg_file``.

        The parent directory is created if it does not exist.

        Raises:
            RuntimeError: if the file or its directory cannot be written.
            TypeError: if a stored value is not JSON-serialisable (the
                existing file is not modified).
        """
        try:
            self._write()
        except OSError as e:
            logger.error(
                "Failed to save configuration file at %s: %s", self.cfg_file, e
            )
            raise RuntimeError(f"Failed to save configuration file: {e}") from e

    def set_default_database(self, name, path):
        """Make *name* at *path* the default database (in memory only).

        The database is also registered in ``database_locations``.
        """
        self.default_database = name
        self.default_path = path
        self.database_locations[name] = path

    def get_database_location(self, name):
        """Return the path registered for *name*, or ``None`` if unknown."""
        return self.database_locations.get(name)

    def set_database_location(self, name, path):
        """Register (or overwrite) the path for database *name* (in memory only)."""
        self.database_locations[name] = path