From 92e54174cce62d962825cc0e46831c1921f132dd Mon Sep 17 00:00:00 2001 From: Sebastian Lenzlinger <74497638+sebaschi@users.noreply.github.com> Date: Thu, 4 Jan 2024 15:47:00 +0100 Subject: [PATCH] Add db_connector.py body and beginnings of map.py. As well as config_example.py to use as central source of truth for connection credentials. --- analysis/config_example.py | 11 +++++++ analysis/db_connector.py | 65 ++++++++++++++++++++++++++++++++++++-- analysis/map.py | 43 +++++++++++++++++++++++++ analysis/plots.py | 0 requirements.txt | 2 ++ 5 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 analysis/config_example.py create mode 100644 analysis/plots.py diff --git a/analysis/config_example.py b/analysis/config_example.py new file mode 100644 index 0000000..830fe5e --- /dev/null +++ b/analysis/config_example.py @@ -0,0 +1,11 @@ +# config.py, adjust as needed +# TODO RENAME THIS FILE TO "config.py" +SSH_HOST = 'slenzlinger.dev' +SSH_USERNAME = 'sebl' #TODO: Enter own username +SSH_PASSWORD = 'your_ssh_password' # TODO: to not push to git +DB_NAME = 'your_database_name' # TODO +DB_USER = 'your_database_username' # TODO +DB_PASSWORD = 'your_database_password' # TODO +DB_HOST = 'your_database_host' # TODO +DB_PORT = 5433 +SSH_PORT = 22 diff --git a/analysis/db_connector.py b/analysis/db_connector.py index e6d1486..90dc740 100644 --- a/analysis/db_connector.py +++ b/analysis/db_connector.py @@ -1,2 +1,63 @@ -import sqlalchemy as db -import pandas as pd +import logging +from sshtunnel import SSHTunnelForwarder +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from config import SSH_HOST, SSH_USERNAME, SSH_PASSWORD, DB_NAME, DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, SSH_PORT + +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger('db_connector.py') +stream_handler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +stream_handler.setFormatter(formatter) +logger.addHandler(stream_handler) + +class RemoteDB: + def __init__(self): + self.ssh_host = SSH_HOST + self.ssh_username = SSH_USERNAME + self.ssh_password = SSH_PASSWORD + self.db_name = DB_NAME + self.db_user = DB_USER + self.db_password = DB_PASSWORD + self.db_host = DB_HOST + self.db_port = DB_PORT + self.ssh_port = SSH_PORT + self.tunnel = None + self.engine = None + self.Session = None + self._connect() + + def _connect(self): + try: + self.tunnel = SSHTunnelForwarder( + (self.ssh_host, self.ssh_port), + ssh_username=self.ssh_username, + ssh_password=self.ssh_password, + remote_bind_address=(self.db_host, self.db_port) + ) + self.tunnel.start() + + local_port = self.tunnel.local_bind_port + db_url = f"postgresql://{self.db_user}:{self.db_password}@localhost:{local_port}/{self.db_name}" + self.engine = create_engine(db_url) + self.Session = sessionmaker(bind=self.engine) + except Exception as e: + logger.exception(f"Connection failed: {e}") + + def execute_query(self, query): + session = self.Session() + try: + result = session.execute(query) + session.commit() + return result.fetchall() + except Exception as e: + session.rollback() + raise + finally: + session.close() + + def close(self): + if self.engine: + self.engine.dispose() + if self.tunnel: + self.tunnel.stop() diff --git a/analysis/map.py b/analysis/map.py index e69de29..e6cebca 100644 --- a/analysis/map.py +++ b/analysis/map.py @@ -0,0 +1,43 @@ +import pandas as pd +import geopandas as gpd +import os +import folium +from folium import plugins +import logging + +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger('map.py') +stream_handler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +stream_handler.setFormatter(formatter) +logger.addHandler(stream_handler) + + +accidents_filepath = "../src/datasets/integrated/Accidents.geojson" +signaled_speeds_filepath = "../src/datasets/integrated/signaled_speeds.geojson.geojson" + +# Map centered around zurich +zurich_coordinates = [47.368650, 8.539183] +fixed_map_zurich_original_coords = folium.Map( + location=zurich_coordinates, + zoom_start=13, + zoom_control=False, + dragging=False, + scrollWheelZoom=False, + doubleClickZoom=False +) +def create_acc_map(): + acc_gdf = gpd.read_file(accidents_filepath) + acc_gdf['latitude'] = acc_gdf.geometry.y + acc_gdf['longitude'] = acc_gdf.geometry.x + + + # Ensure we're dealing with floats + acc_gdf['latitude'] = acc_gdf['latitude'].astype(float) + acc_gdf['longitude'] = acc_gdf['longitude'].astype(float) + + # Build heat dataframe used for mapping + heat_df = acc_gdf + heat_df = heat_df[['latitude', 'longitude']] + heat_df = heat_df.dropna(axis=0, subset=['latitude', 'longitude']) + heat_df = heat_df.dropna(axis=0, subset=['latitude', 'longitude']) diff --git a/analysis/plots.py b/analysis/plots.py new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt index 1681400..be08387 100755 --- a/requirements.txt +++ b/requirements.txt @@ -135,3 +135,5 @@ Werkzeug==3.0.1 widgetsnbextension==4.0.9 xyzservices==2023.10.1 zipp==3.17.0 + +sshtunnel~=0.4.0 \ No newline at end of file