From eccfba5bec04ac00979604d55eee472b90037eb1 Mon Sep 17 00:00:00 2001 From: Joey Hines Date: Sun, 15 Jul 2018 10:56:56 -0500 Subject: [PATCH] Added LocationNameNotUniqueError to handle unique constraint errors better --- BotErrors.py | 8 +++++--- DatabaseModels.py | 34 +++++++++++++++++++--------------- Geoffrey.py | 10 ++++------ test_geoffreyDatabase.py | 20 ++++++++++++++------ 4 files changed, 42 insertions(+), 30 deletions(-) diff --git a/BotErrors.py b/BotErrors.py index 774bf19..b5edf6a 100644 --- a/BotErrors.py +++ b/BotErrors.py @@ -2,15 +2,14 @@ class DataBaseError(Exception): '''Base class for exceptions in this module.''' pass - class LocationInitError(DataBaseError): '''Error in initializing Location''' -class LocationLookUpError(DataBaseError) : +class LocationLookUpError(DataBaseError): '''Error in finding location in database''' -class DeleteEntryError(DataBaseError) : +class DeleteEntryError(DataBaseError): '''Error in deleting entry''' class UsernameLookupFailed(Exception): @@ -19,4 +18,7 @@ class UsernameLookupFailed(Exception): class PlayerNotFound(DataBaseError): '''Player not found in database.''' +class LocationNameNotUniqueError(DataBaseError): + '''A location by that name is already in the database.''' + diff --git a/DatabaseModels.py b/DatabaseModels.py index 3409842..8d3460d 100644 --- a/DatabaseModels.py +++ b/DatabaseModels.py @@ -4,6 +4,7 @@ from sqlalchemy.ext.declarative import declarative_base from BotErrors import * from sqlalchemy import create_engine, exists, literal from sqlalchemy.orm import sessionmaker, relationship +from sqlalchemy.exc import IntegrityError import sqlalchemy from MinecraftAccountInfoGrabber import * @@ -77,9 +78,9 @@ class DatabaseInterface: expr = (Location.owner == owner) & (Location.name.ilike(name)) return self.database.query_by_filter(Location, expr) - def find_location_around(self, x_pos, z_pos, radius): + def find_location_around(self, x_pos, z_pos, radius, dimension): expr = (Location.x < x_pos + radius + 1) & (Location.x > x_pos - radius - 1) & (Location.z < z_pos + radius + 1) \ - & (Location.z > z_pos - radius - 1) + & (Location.z > z_pos - radius - 1) & (Location.dimension == dimension) return self.database.query_by_filter(Location, expr) @@ -193,10 +194,13 @@ class GeoffreyDatabase: SQL_Base.metadata.create_all(self.engine) def add_object(self, obj): - ret = not self.session.query(exists().where(type(obj).id == obj.id)) - if not ret: - self.session.add(obj) - self.session.commit() + try: + ret = not self.session.query(exists().where(type(obj).id == obj.id)) + if not ret: + self.session.add(obj) + self.session.commit() + except IntegrityError: + raise LocationNameNotUniqueError def query_by_filter(self, obj_type, * args): filter_value = self.combine_filter(args) @@ -232,13 +236,13 @@ class TunnelDirection(enum.Enum): def str_to_tunnel_dir(arg): arg = arg.lower() - if arg == TunnelDirection.North.value: + if arg in TunnelDirection.North.value: return TunnelDirection.North - elif arg == TunnelDirection.East.value: + elif arg in TunnelDirection.East.value: return TunnelDirection.East - elif arg == TunnelDirection.South.value: + elif arg in TunnelDirection.South.value: return TunnelDirection.South - elif arg == TunnelDirection.West.value: + elif arg in TunnelDirection.West.value: return TunnelDirection.West else: raise ValueError @@ -250,9 +254,9 @@ class TunnelSide(enum.Enum): def str_to_tunnel_side(arg): arg = arg.lower() - if arg == TunnelSide.right.value: + if arg in TunnelSide.right.value: return TunnelSide.right - elif arg == TunnelSide.left.value: + elif arg in TunnelSide.left.value: return TunnelSide.left else: raise ValueError @@ -265,11 +269,11 @@ class Dimension(enum.Enum): def str_to_dimension(arg): arg = arg.lower() - if arg == Dimension.overworld.value: + if arg in Dimension.overworld.value: return Dimension.overworld - elif arg == Dimension.nether.value: + elif arg in Dimension.nether.value: return Dimension.nether - elif arg == Dimension.end.value: + elif arg in Dimension.end.value: return Dimension.end else: raise ValueError diff --git a/Geoffrey.py b/Geoffrey.py index 0b6786b..509f558 100644 --- a/Geoffrey.py +++ b/Geoffrey.py @@ -3,7 +3,7 @@ from DatabaseModels import * from BotErrors import * from MinecraftAccountInfoGrabber import * import configparser -import sqlite3 +import shlex #from WebInterface import * TOKEN = '' @@ -48,12 +48,9 @@ async def on_command_error(error, ctx): elif isinstance(error.original, PlayerNotFound): error_str = 'Make sure to use ?register first you ding dong.' database_interface.database.session.rollback() - elif isinstance(error.original.orig, sqlite3.IntegrityError): + elif isinstance(error.original, LocationNameNotUniqueError): error_str = 'An entry in the database already has that name ding dong.' database_interface.database.session.rollback() - elif isinstance(error.original, sqlite3.IntegrityError): - error_str = 'Oof, the fuck did you do? Try the command again but be less of a ding dong with it.' - database_interface.database.session.rollback() else: error_str = bad_error_message.format(ctx.invoked_with, error) @@ -165,6 +162,7 @@ async def findaround(ctx, x_pos: int, z_pos: int, * args): ''' radius = 200 + if len(args) > 0: try: radius = int(args[0]) @@ -260,7 +258,6 @@ def create_config(): with open('GeoffreyConfig.ini', 'w') as configfile: config.write(configfile) - def get_engine_arg(config): driver = config['SQL']['Dialect+Driver'] username = config['SQL']['username'] @@ -294,5 +291,6 @@ else: database_interface = DiscordDatabaseInterface(engine_arg) #WebInterface('127.0.0.1', 8081, database) + bot.run(TOKEN) diff --git a/test_geoffreyDatabase.py b/test_geoffreyDatabase.py index 8dbfbc1..f654555 100644 --- a/test_geoffreyDatabase.py +++ b/test_geoffreyDatabase.py @@ -99,27 +99,29 @@ class TestGeoffreyDatabase(TestCase): owner = self.interface.add_player('ZeroHD', '143072699567177728') loc = self.interface.add_location('143072699567177728', 'test', 0, 0, 0, ['Green', 0, "Right"]) - loc_list = self.interface.find_location_around(100, 100, 100) + dim = Dimension.str_to_dimension("O") + + loc_list = self.interface.find_location_around(100, 100, 100, dim) self.assertEqual(loc_list[0].name, loc.name) - loc_list = self.interface.find_location_around(200, 200, 100) + loc_list = self.interface.find_location_around(200, 200, 100, dim) self.assertEqual(len(loc_list), 0) - loc_list = self.interface.find_location_around(-100, -100, 100) + loc_list = self.interface.find_location_around(-100, -100, 100, dim) self.assertEqual(loc_list[0].name, loc.name) - loc_list = self.interface.find_location_around(100, -100, 100) + loc_list = self.interface.find_location_around(100, -100, 100, dim) self.assertEqual(loc_list[0].name, loc.name) - loc_list = self.interface.find_location_around(-100, 100, 100) + loc_list = self.interface.find_location_around(-100, 100, 100, dim) self.assertEqual(loc_list[0].name, loc.name) - loc_list = self.interface.find_location_around(50, -50, 100) + loc_list = self.interface.find_location_around(50, -50, 100, dim) self.assertEqual(loc_list[0].name, loc.name) @@ -164,6 +166,12 @@ class TestGeoffreyDatabase(TestCase): self.assertEqual(loc_list[0].id, loc.id) + def test_duplicate_name(self): + self.interface.add_player('ZeroHD', '143072699567177728') + self.interface.add_location('143072699567177728', 'test', 0, 0, 0, ['Green', 0, "right"]) + + self.assertRaises(LocationNameNotUniqueError, self.interface.add_location, + '143072699567177728', 'test', 0, 0, 0, ['Green', 0, "right"])