Add db_connector.py body and beginnings of map.py. As well as config_example.py to use as central source of truth for connection credentials.
This commit is contained in:
parent
ab1baacf47
commit
92e54174cc
11
analysis/config_example.py
Normal file
11
analysis/config_example.py
Normal file
@ -0,0 +1,11 @@
|
||||
# config.py, adjust as needed
|
||||
# TODO RENAME THIS FILE TO "config.py"
|
||||
SSH_HOST = 'slenzlinger.dev'
|
||||
SSH_USERNAME = 'sebl' #TODO: Enter own username
|
||||
SSH_PASSWORD = 'your_ssh_password' # TODO: to not push to git
|
||||
DB_NAME = 'your_database_name' # TODO
|
||||
DB_USER = 'your_database_username' # TODO
|
||||
DB_PASSWORD = 'your_database_password' # TODO
|
||||
DB_HOST = 'your_database_host' # TODO
|
||||
DB_PORT = 5433
|
||||
SSH_PORT = 22
|
||||
@ -1,2 +1,63 @@
|
||||
import sqlalchemy as db
|
||||
import pandas as pd
|
||||
import logging
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from config import SSH_HOST, SSH_USERNAME, SSH_PASSWORD, DB_NAME, DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, SSH_PORT
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger('db_connector.py')
|
||||
stream_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
class RemoteDB:
|
||||
def __init__(self):
|
||||
self.ssh_host = SSH_HOST
|
||||
self.ssh_username = SSH_USERNAME
|
||||
self.ssh_password = SSH_PASSWORD
|
||||
self.db_name = DB_NAME
|
||||
self.db_user = DB_USER
|
||||
self.db_password = DB_PASSWORD
|
||||
self.db_host = DB_HOST
|
||||
self.db_port = DB_PORT
|
||||
self.ssh_port = SSH_PORT
|
||||
self.tunnel = None
|
||||
self.engine = None
|
||||
self.Session = None
|
||||
self._connect()
|
||||
|
||||
def _connect(self):
|
||||
try:
|
||||
self.tunnel = SSHTunnelForwarder(
|
||||
(self.ssh_host, self.ssh_port),
|
||||
ssh_username=self.ssh_username,
|
||||
ssh_password=self.ssh_password,
|
||||
remote_bind_address=(self.db_host, self.db_port)
|
||||
)
|
||||
self.tunnel.start()
|
||||
|
||||
local_port = self.tunnel.local_bind_port
|
||||
db_url = f"postgresql://{self.db_user}:{self.db_password}@localhost:{local_port}/{self.db_name}"
|
||||
self.engine = create_engine(db_url)
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
except Exception as e:
|
||||
logger.exception(f"Connection failed: {e}")
|
||||
|
||||
def execute_query(self, query):
|
||||
session = self.Session()
|
||||
try:
|
||||
result = session.execute(query)
|
||||
session.commit()
|
||||
return result.fetchall()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def close(self):
|
||||
if self.engine:
|
||||
self.engine.dispose()
|
||||
if self.tunnel:
|
||||
self.tunnel.stop()
|
||||
|
||||
@ -0,0 +1,43 @@
|
||||
import pandas as pd
|
||||
import geopandas as gpd
|
||||
import os
|
||||
import folium
|
||||
from folium import plugins
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger('map.py')
|
||||
stream_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
accidents_filepath = "../src/datasets/integrated/Accidents.geojson"
|
||||
signaled_speeds_filepath = "../src/datasets/integrated/signaled_speeds.geojson.geojson"
|
||||
|
||||
# Map centered around zurich
|
||||
zurich_coordinates = [47.368650, 8.539183]
|
||||
fixed_map_zurich_original_coords = folium.Map(
|
||||
location=zurich_coordinates,
|
||||
zoom_start=13,
|
||||
zoom_control=False,
|
||||
dragging=False,
|
||||
scrollWheelZoom=False,
|
||||
doubleClickZoom=False
|
||||
)
|
||||
def create_acc_map():
|
||||
acc_gdf = gpd.read_file(accidents_filepath)
|
||||
acc_gdf['latitude'] = acc_gdf.geometry.y
|
||||
acc_gdf['longitude'] = acc_gdf.geometry.x
|
||||
|
||||
|
||||
# Ensure we're dealing with floats
|
||||
acc_gdf['latitude'] = acc_gdf['latitude'].astype(float)
|
||||
acc_gdf['longitude'] = acc_gdf['longitude'].astype(float)
|
||||
|
||||
# Build heat dataframe used for mapping
|
||||
heat_df = acc_gdf
|
||||
heat_df = heat_df[['latitude', 'longitude']]
|
||||
heat_df = heat_df.dropna(axis=0, subset=['latitude', 'longitude'])
|
||||
heat_df = heat_df.dropna(axis=0, subset=['latitude', 'longitude'])
|
||||
0
analysis/plots.py
Normal file
0
analysis/plots.py
Normal file
@ -135,3 +135,5 @@ Werkzeug==3.0.1
|
||||
widgetsnbextension==4.0.9
|
||||
xyzservices==2023.10.1
|
||||
zipp==3.17.0
|
||||
|
||||
sshtunnel~=0.4.0
|
||||
Reference in New Issue
Block a user