Add db_connector.py body and the beginnings of map.py, as well as config_example.py to use as a central source of truth for connection credentials.
This commit is contained in:
parent
ab1baacf47
commit
92e54174cc
11
analysis/config_example.py
Normal file
@@ -0,0 +1,11 @@
# config.py, adjust as needed
# TODO: rename this file to "config.py"
SSH_HOST = 'slenzlinger.dev'
SSH_USERNAME = 'sebl'  # TODO: enter your own username
SSH_PASSWORD = 'your_ssh_password'  # TODO: do not push this to git
DB_NAME = 'your_database_name'  # TODO
DB_USER = 'your_database_username'  # TODO
DB_PASSWORD = 'your_database_password'  # TODO
DB_HOST = 'your_database_host'  # TODO
DB_PORT = 5433
SSH_PORT = 22
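The TODO comments imply the intended workflow: copy this file to config.py in the same directory, fill in real credentials, and keep config.py out of version control so the SSH and database passwords are never pushed. A small, hypothetical guard (not part of the commit) that makes a missing config.py obvious:

# Hypothetical helper, not in the commit: fail early with a clear message
# if config_example.py has not yet been copied to config.py.
try:
    import config  # the renamed copy of config_example.py with real credentials
except ImportError as exc:
    raise SystemExit(
        "Copy analysis/config_example.py to analysis/config.py, fill in your "
        "credentials, and keep config.py out of git (e.g. via .gitignore)."
    ) from exc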
analysis/db_connector.py
@@ -1,2 +1,63 @@
import sqlalchemy as db
import logging
import pandas as pd
from sshtunnel import SSHTunnelForwarder
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from config import SSH_HOST, SSH_USERNAME, SSH_PASSWORD, DB_NAME, DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, SSH_PORT

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('db_connector.py')
stream_handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)


class RemoteDB:
    def __init__(self):
        self.ssh_host = SSH_HOST
        self.ssh_username = SSH_USERNAME
        self.ssh_password = SSH_PASSWORD
        self.db_name = DB_NAME
        self.db_user = DB_USER
        self.db_password = DB_PASSWORD
        self.db_host = DB_HOST
        self.db_port = DB_PORT
        self.ssh_port = SSH_PORT
        self.tunnel = None
        self.engine = None
        self.Session = None
        self._connect()

    def _connect(self):
        try:
            self.tunnel = SSHTunnelForwarder(
                (self.ssh_host, self.ssh_port),
                ssh_username=self.ssh_username,
                ssh_password=self.ssh_password,
                remote_bind_address=(self.db_host, self.db_port)
            )
            self.tunnel.start()

            local_port = self.tunnel.local_bind_port
            db_url = f"postgresql://{self.db_user}:{self.db_password}@localhost:{local_port}/{self.db_name}"
            self.engine = create_engine(db_url)
            self.Session = sessionmaker(bind=self.engine)
        except Exception as e:
            logger.exception(f"Connection failed: {e}")

    def execute_query(self, query):
        session = self.Session()
        try:
            result = session.execute(query)
            session.commit()
            return result.fetchall()
        except Exception as e:
            session.rollback()
            raise
        finally:
            session.close()

    def close(self):
        if self.engine:
            self.engine.dispose()
        if self.tunnel:
            self.tunnel.stop()
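For orientation, here is a minimal usage sketch of the RemoteDB class added above, assuming config.py is already in place. The example query, the sqlalchemy.text() wrapper, and reading into pandas are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch, not part of the commit: pull rows through the
# SSH tunnel. Wrapping raw SQL in sqlalchemy.text() is assumed here; newer
# SQLAlchemy versions require it for string queries.
import pandas as pd
from sqlalchemy import text

from db_connector import RemoteDB

remote_db = RemoteDB()  # opens the tunnel and builds the engine in __init__
try:
    rows = remote_db.execute_query(text("SELECT 1 AS ok"))
    print(rows)  # e.g. [(1,)]

    # Alternatively, read straight into a DataFrame via the underlying engine.
    df = pd.read_sql(text("SELECT 1 AS ok"), remote_db.engine)
    print(df.head())
finally:
    remote_db.close()  # dispose the engine and stop the SSH tunnel

On SQLAlchemy 1.x a plain string also works with execute_query; 2.x raises unless the statement is wrapped in text().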
43
analysis/map.py
Normal file
@@ -0,0 +1,43 @@
import pandas as pd
import geopandas as gpd
import os
import folium
from folium import plugins
import logging

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('map.py')
stream_handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)


accidents_filepath = "../src/datasets/integrated/Accidents.geojson"
signaled_speeds_filepath = "../src/datasets/integrated/signaled_speeds.geojson.geojson"

# Map centered around Zurich
zurich_coordinates = [47.368650, 8.539183]
fixed_map_zurich_original_coords = folium.Map(
    location=zurich_coordinates,
    zoom_start=13,
    zoom_control=False,
    dragging=False,
    scrollWheelZoom=False,
    doubleClickZoom=False
)

def create_acc_map():
    acc_gdf = gpd.read_file(accidents_filepath)
    acc_gdf['latitude'] = acc_gdf.geometry.y
    acc_gdf['longitude'] = acc_gdf.geometry.x

    # Ensure we're dealing with floats
    acc_gdf['latitude'] = acc_gdf['latitude'].astype(float)
    acc_gdf['longitude'] = acc_gdf['longitude'].astype(float)

    # Build heat dataframe used for mapping
    heat_df = acc_gdf
    heat_df = heat_df[['latitude', 'longitude']]
    heat_df = heat_df.dropna(axis=0, subset=['latitude', 'longitude'])
    heat_df = heat_df.dropna(axis=0, subset=['latitude', 'longitude'])
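map.py stops after building heat_df, so the actual heatmap rendering is still to come. A hypothetical continuation (not in this commit) showing how such a frame is typically fed to folium's HeatMap plugin and written to HTML:

# Hypothetical continuation sketch, not part of this commit: render accident
# points as a folium heatmap. The [lat, lon] list format, the radius value,
# and the output filename are assumptions.
import folium
from folium import plugins
import geopandas as gpd

acc_gdf = gpd.read_file("../src/datasets/integrated/Accidents.geojson")
heat_data = acc_gdf.geometry.apply(lambda point: [point.y, point.x]).tolist()

heat_map = folium.Map(location=[47.368650, 8.539183], zoom_start=13)
plugins.HeatMap(heat_data, radius=10).add_to(heat_map)
heat_map.save("accidents_heatmap.html")  # assumed output path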
0
analysis/plots.py
Normal file
@@ -135,3 +135,5 @@ Werkzeug==3.0.1
widgetsnbextension==4.0.9
xyzservices==2023.10.1
zipp==3.17.0

sshtunnel~=0.4.0