Added LocationNameNotUniqueError to handle unique constraint errors better

doc_update
Joey Hines 2018-07-15 10:56:56 -05:00
parent e7beb95903
commit eccfba5bec
4 changed files with 42 additions and 30 deletions

View File

@ -2,15 +2,14 @@ class DataBaseError(Exception):
'''Base class for exceptions in this module.''' '''Base class for exceptions in this module.'''
pass pass
class LocationInitError(DataBaseError): class LocationInitError(DataBaseError):
'''Error in initializing Location''' '''Error in initializing Location'''
class LocationLookUpError(DataBaseError) : class LocationLookUpError(DataBaseError):
'''Error in finding location in database''' '''Error in finding location in database'''
class DeleteEntryError(DataBaseError) : class DeleteEntryError(DataBaseError):
'''Error in deleting entry''' '''Error in deleting entry'''
class UsernameLookupFailed(Exception): class UsernameLookupFailed(Exception):
@ -19,4 +18,7 @@ class UsernameLookupFailed(Exception):
class PlayerNotFound(DataBaseError): class PlayerNotFound(DataBaseError):
'''Player not found in database.''' '''Player not found in database.'''
class LocationNameNotUniqueError(DataBaseError):
'''A location by that name is already in the database.'''

View File

@ -4,6 +4,7 @@ from sqlalchemy.ext.declarative import declarative_base
from BotErrors import * from BotErrors import *
from sqlalchemy import create_engine, exists, literal from sqlalchemy import create_engine, exists, literal
from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.exc import IntegrityError
import sqlalchemy import sqlalchemy
from MinecraftAccountInfoGrabber import * from MinecraftAccountInfoGrabber import *
@ -77,9 +78,9 @@ class DatabaseInterface:
expr = (Location.owner == owner) & (Location.name.ilike(name)) expr = (Location.owner == owner) & (Location.name.ilike(name))
return self.database.query_by_filter(Location, expr) return self.database.query_by_filter(Location, expr)
def find_location_around(self, x_pos, z_pos, radius): def find_location_around(self, x_pos, z_pos, radius, dimension):
expr = (Location.x < x_pos + radius + 1) & (Location.x > x_pos - radius - 1) & (Location.z < z_pos + radius + 1) \ expr = (Location.x < x_pos + radius + 1) & (Location.x > x_pos - radius - 1) & (Location.z < z_pos + radius + 1) \
& (Location.z > z_pos - radius - 1) & (Location.z > z_pos - radius - 1) & (Location.dimension == dimension)
return self.database.query_by_filter(Location, expr) return self.database.query_by_filter(Location, expr)
@ -193,10 +194,13 @@ class GeoffreyDatabase:
SQL_Base.metadata.create_all(self.engine) SQL_Base.metadata.create_all(self.engine)
def add_object(self, obj): def add_object(self, obj):
ret = not self.session.query(exists().where(type(obj).id == obj.id)) try:
if not ret: ret = not self.session.query(exists().where(type(obj).id == obj.id))
self.session.add(obj) if not ret:
self.session.commit() self.session.add(obj)
self.session.commit()
except IntegrityError:
raise LocationNameNotUniqueError
def query_by_filter(self, obj_type, * args): def query_by_filter(self, obj_type, * args):
filter_value = self.combine_filter(args) filter_value = self.combine_filter(args)
@ -232,13 +236,13 @@ class TunnelDirection(enum.Enum):
def str_to_tunnel_dir(arg): def str_to_tunnel_dir(arg):
arg = arg.lower() arg = arg.lower()
if arg == TunnelDirection.North.value: if arg in TunnelDirection.North.value:
return TunnelDirection.North return TunnelDirection.North
elif arg == TunnelDirection.East.value: elif arg in TunnelDirection.East.value:
return TunnelDirection.East return TunnelDirection.East
elif arg == TunnelDirection.South.value: elif arg in TunnelDirection.South.value:
return TunnelDirection.South return TunnelDirection.South
elif arg == TunnelDirection.West.value: elif arg in TunnelDirection.West.value:
return TunnelDirection.West return TunnelDirection.West
else: else:
raise ValueError raise ValueError
@ -250,9 +254,9 @@ class TunnelSide(enum.Enum):
def str_to_tunnel_side(arg): def str_to_tunnel_side(arg):
arg = arg.lower() arg = arg.lower()
if arg == TunnelSide.right.value: if arg in TunnelSide.right.value:
return TunnelSide.right return TunnelSide.right
elif arg == TunnelSide.left.value: elif arg in TunnelSide.left.value:
return TunnelSide.left return TunnelSide.left
else: else:
raise ValueError raise ValueError
@ -265,11 +269,11 @@ class Dimension(enum.Enum):
def str_to_dimension(arg): def str_to_dimension(arg):
arg = arg.lower() arg = arg.lower()
if arg == Dimension.overworld.value: if arg in Dimension.overworld.value:
return Dimension.overworld return Dimension.overworld
elif arg == Dimension.nether.value: elif arg in Dimension.nether.value:
return Dimension.nether return Dimension.nether
elif arg == Dimension.end.value: elif arg in Dimension.end.value:
return Dimension.end return Dimension.end
else: else:
raise ValueError raise ValueError

View File

@ -3,7 +3,7 @@ from DatabaseModels import *
from BotErrors import * from BotErrors import *
from MinecraftAccountInfoGrabber import * from MinecraftAccountInfoGrabber import *
import configparser import configparser
import sqlite3 import shlex
#from WebInterface import * #from WebInterface import *
TOKEN = '' TOKEN = ''
@ -48,12 +48,9 @@ async def on_command_error(error, ctx):
elif isinstance(error.original, PlayerNotFound): elif isinstance(error.original, PlayerNotFound):
error_str = 'Make sure to use ?register first you ding dong.' error_str = 'Make sure to use ?register first you ding dong.'
database_interface.database.session.rollback() database_interface.database.session.rollback()
elif isinstance(error.original.orig, sqlite3.IntegrityError): elif isinstance(error.original, LocationNameNotUniqueError):
error_str = 'An entry in the database already has that name ding dong.' error_str = 'An entry in the database already has that name ding dong.'
database_interface.database.session.rollback() database_interface.database.session.rollback()
elif isinstance(error.original, sqlite3.IntegrityError):
error_str = 'Oof, the fuck did you do? Try the command again but be less of a ding dong with it.'
database_interface.database.session.rollback()
else: else:
error_str = bad_error_message.format(ctx.invoked_with, error) error_str = bad_error_message.format(ctx.invoked_with, error)
@ -165,6 +162,7 @@ async def findaround(ctx, x_pos: int, z_pos: int, * args):
''' '''
radius = 200 radius = 200
if len(args) > 0: if len(args) > 0:
try: try:
radius = int(args[0]) radius = int(args[0])
@ -260,7 +258,6 @@ def create_config():
with open('GeoffreyConfig.ini', 'w') as configfile: with open('GeoffreyConfig.ini', 'w') as configfile:
config.write(configfile) config.write(configfile)
def get_engine_arg(config): def get_engine_arg(config):
driver = config['SQL']['Dialect+Driver'] driver = config['SQL']['Dialect+Driver']
username = config['SQL']['username'] username = config['SQL']['username']
@ -294,5 +291,6 @@ else:
database_interface = DiscordDatabaseInterface(engine_arg) database_interface = DiscordDatabaseInterface(engine_arg)
#WebInterface('127.0.0.1', 8081, database) #WebInterface('127.0.0.1', 8081, database)
bot.run(TOKEN) bot.run(TOKEN)

View File

@ -99,27 +99,29 @@ class TestGeoffreyDatabase(TestCase):
owner = self.interface.add_player('ZeroHD', '143072699567177728') owner = self.interface.add_player('ZeroHD', '143072699567177728')
loc = self.interface.add_location('143072699567177728', 'test', 0, 0, 0, ['Green', 0, "Right"]) loc = self.interface.add_location('143072699567177728', 'test', 0, 0, 0, ['Green', 0, "Right"])
loc_list = self.interface.find_location_around(100, 100, 100) dim = Dimension.str_to_dimension("O")
loc_list = self.interface.find_location_around(100, 100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(200, 200, 100) loc_list = self.interface.find_location_around(200, 200, 100, dim)
self.assertEqual(len(loc_list), 0) self.assertEqual(len(loc_list), 0)
loc_list = self.interface.find_location_around(-100, -100, 100) loc_list = self.interface.find_location_around(-100, -100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(100, -100, 100) loc_list = self.interface.find_location_around(100, -100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(-100, 100, 100) loc_list = self.interface.find_location_around(-100, 100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(50, -50, 100) loc_list = self.interface.find_location_around(50, -50, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
@ -164,6 +166,12 @@ class TestGeoffreyDatabase(TestCase):
self.assertEqual(loc_list[0].id, loc.id) self.assertEqual(loc_list[0].id, loc.id)
def test_duplicate_name(self):
self.interface.add_player('ZeroHD', '143072699567177728')
self.interface.add_location('143072699567177728', 'test', 0, 0, 0, ['Green', 0, "right"])
self.assertRaises(LocationNameNotUniqueError, self.interface.add_location,
'143072699567177728', 'test', 0, 0, 0, ['Green', 0, "right"])