diff --git a/analysis/plots.py b/analysis/plots.py index 4a17afe..bda3854 100644 --- a/analysis/plots.py +++ b/analysis/plots.py @@ -9,6 +9,7 @@ import numpy as np logging.getLogger("matplotlib").setLevel(logging.WARNING) + # TODO def plt_acc_by_year(db): acc_year_sql = """ @@ -18,10 +19,10 @@ def plt_acc_by_year(db): """ result = db.execute_query(acc_year_sql) result_df = pd.DataFrame(result) - plt.barh(result_df['accidentyear'],result_df['count']) - plt.ylabel('Year') - plt.xlabel('No. of Accidents') - plt.show() + + fig = px.bar(result_df, y='year', x='count', orientation='h', title='No. of Accidents per Year') + fig.write_image("fig/acc_by_year.png") + fig.write_html("html/acc_by_year.png") def plt_acc_by_weekday(db): @@ -35,12 +36,18 @@ def plt_acc_by_weekday(db): result = db.execute_query(acc_weekday_sql) result_df = pd.DataFrame(result) - plt.barh(result_df['weekday'], result_df['count']) - plt.ylabel('Weekday') - plt.xlabel('No. of Accidents') - plt.show() + fig = px.bar(result_df, y='weekday', x='Count', orientation='h', title='No. of Accidents per Weekday') + fig.write_image("fig/acc_by_weekday.png") + fig.write_html("html/acc_by_weekday.html") +def plt_acc_by_day_year(db): + acc_year_day_sql = """ + SELECT accidentyear AS year, accidentweekday_en AS weekday, COUNT(*) AS count + FROM accidents + GROUP BY weekday, year + ORDER BY year, COUNT(*); + """ def plt_acc_by_daytime(db): @@ -54,24 +61,27 @@ def plt_acc_by_daytime(db): result = db.execute_query(acc_weekday_sql) result_df = pd.DataFrame(result) - plt.barh(result_df['hour'], result_df['count']) - plt.ylabel('hour') - plt.xlabel('No. of Accidents') - plt.show() - fig = px.bar(result_df, y='hour', x='count', orientation='h') fig.write_image("fig/acc_by_day.png") fig.write_html("html/acc_by_day.html") +# Utilities =========================================================================================================== +def save_as_barplot(df, xname, yname, orientation, file_name): + pass + + +def save_as_html(): + pass + + if __name__ == "__main__": remote_db = RemoteDB() try: - #plt_acc_by_year(remote_db) - #plt_acc_by_weekday(remote_db) + # plt_acc_by_year(remote_db) + # plt_acc_by_weekday(remote_db) plt_acc_by_daytime(remote_db) except Exception as e: print(f"Exception {e} in plots.py") finally: remote_db.close() -