Added LocationNameNotUniqueError to handle unique constraint errors better

doc_update
Joey Hines 2018-07-15 10:56:56 -05:00
parent e7beb95903
commit eccfba5bec
4 changed files with 42 additions and 30 deletions

View File

@ -2,15 +2,14 @@ class DataBaseError(Exception):
'''Base class for exceptions in this module.'''
pass
class LocationInitError(DataBaseError):
'''Error in initializing Location'''
class LocationLookUpError(DataBaseError) :
class LocationLookUpError(DataBaseError):
'''Error in finding location in database'''
class DeleteEntryError(DataBaseError) :
class DeleteEntryError(DataBaseError):
'''Error in deleting entry'''
class UsernameLookupFailed(Exception):
@ -19,4 +18,7 @@ class UsernameLookupFailed(Exception):
class PlayerNotFound(DataBaseError):
'''Player not found in database.'''
class LocationNameNotUniqueError(DataBaseError):
'''A location by that name is already in the database.'''

View File

@ -4,6 +4,7 @@ from sqlalchemy.ext.declarative import declarative_base
from BotErrors import *
from sqlalchemy import create_engine, exists, literal
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.exc import IntegrityError
import sqlalchemy
from MinecraftAccountInfoGrabber import *
@ -77,9 +78,9 @@ class DatabaseInterface:
expr = (Location.owner == owner) & (Location.name.ilike(name))
return self.database.query_by_filter(Location, expr)
def find_location_around(self, x_pos, z_pos, radius):
def find_location_around(self, x_pos, z_pos, radius, dimension):
expr = (Location.x < x_pos + radius + 1) & (Location.x > x_pos - radius - 1) & (Location.z < z_pos + radius + 1) \
& (Location.z > z_pos - radius - 1)
& (Location.z > z_pos - radius - 1) & (Location.dimension == dimension)
return self.database.query_by_filter(Location, expr)
@ -193,10 +194,13 @@ class GeoffreyDatabase:
SQL_Base.metadata.create_all(self.engine)
def add_object(self, obj):
ret = not self.session.query(exists().where(type(obj).id == obj.id))
if not ret:
self.session.add(obj)
self.session.commit()
try:
ret = not self.session.query(exists().where(type(obj).id == obj.id))
if not ret:
self.session.add(obj)
self.session.commit()
except IntegrityError:
raise LocationNameNotUniqueError
def query_by_filter(self, obj_type, * args):
filter_value = self.combine_filter(args)
@ -232,13 +236,13 @@ class TunnelDirection(enum.Enum):
def str_to_tunnel_dir(arg):
arg = arg.lower()
if arg == TunnelDirection.North.value:
if arg in TunnelDirection.North.value:
return TunnelDirection.North
elif arg == TunnelDirection.East.value:
elif arg in TunnelDirection.East.value:
return TunnelDirection.East
elif arg == TunnelDirection.South.value:
elif arg in TunnelDirection.South.value:
return TunnelDirection.South
elif arg == TunnelDirection.West.value:
elif arg in TunnelDirection.West.value:
return TunnelDirection.West
else:
raise ValueError
@ -250,9 +254,9 @@ class TunnelSide(enum.Enum):
def str_to_tunnel_side(arg):
arg = arg.lower()
if arg == TunnelSide.right.value:
if arg in TunnelSide.right.value:
return TunnelSide.right
elif arg == TunnelSide.left.value:
elif arg in TunnelSide.left.value:
return TunnelSide.left
else:
raise ValueError
@ -265,11 +269,11 @@ class Dimension(enum.Enum):
def str_to_dimension(arg):
arg = arg.lower()
if arg == Dimension.overworld.value:
if arg in Dimension.overworld.value:
return Dimension.overworld
elif arg == Dimension.nether.value:
elif arg in Dimension.nether.value:
return Dimension.nether
elif arg == Dimension.end.value:
elif arg in Dimension.end.value:
return Dimension.end
else:
raise ValueError

View File

@ -3,7 +3,7 @@ from DatabaseModels import *
from BotErrors import *
from MinecraftAccountInfoGrabber import *
import configparser
import sqlite3
import shlex
#from WebInterface import *
TOKEN = ''
@ -48,12 +48,9 @@ async def on_command_error(error, ctx):
elif isinstance(error.original, PlayerNotFound):
error_str = 'Make sure to use ?register first you ding dong.'
database_interface.database.session.rollback()
elif isinstance(error.original.orig, sqlite3.IntegrityError):
elif isinstance(error.original, LocationNameNotUniqueError):
error_str = 'An entry in the database already has that name ding dong.'
database_interface.database.session.rollback()
elif isinstance(error.original, sqlite3.IntegrityError):
error_str = 'Oof, the fuck did you do? Try the command again but be less of a ding dong with it.'
database_interface.database.session.rollback()
else:
error_str = bad_error_message.format(ctx.invoked_with, error)
@ -165,6 +162,7 @@ async def findaround(ctx, x_pos: int, z_pos: int, * args):
'''
radius = 200
if len(args) > 0:
try:
radius = int(args[0])
@ -260,7 +258,6 @@ def create_config():
with open('GeoffreyConfig.ini', 'w') as configfile:
config.write(configfile)
def get_engine_arg(config):
driver = config['SQL']['Dialect+Driver']
username = config['SQL']['username']
@ -294,5 +291,6 @@ else:
database_interface = DiscordDatabaseInterface(engine_arg)
#WebInterface('127.0.0.1', 8081, database)
bot.run(TOKEN)

View File

@ -99,27 +99,29 @@ class TestGeoffreyDatabase(TestCase):
owner = self.interface.add_player('ZeroHD', '143072699567177728')
loc = self.interface.add_location('143072699567177728', 'test', 0, 0, 0, ['Green', 0, "Right"])
loc_list = self.interface.find_location_around(100, 100, 100)
dim = Dimension.str_to_dimension("O")
loc_list = self.interface.find_location_around(100, 100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(200, 200, 100)
loc_list = self.interface.find_location_around(200, 200, 100, dim)
self.assertEqual(len(loc_list), 0)
loc_list = self.interface.find_location_around(-100, -100, 100)
loc_list = self.interface.find_location_around(-100, -100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(100, -100, 100)
loc_list = self.interface.find_location_around(100, -100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(-100, 100, 100)
loc_list = self.interface.find_location_around(-100, 100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(50, -50, 100)
loc_list = self.interface.find_location_around(50, -50, 100, dim)
self.assertEqual(loc_list[0].name, loc.name)
@ -164,6 +166,12 @@ class TestGeoffreyDatabase(TestCase):
self.assertEqual(loc_list[0].id, loc.id)
def test_duplicate_name(self):
self.interface.add_player('ZeroHD', '143072699567177728')
self.interface.add_location('143072699567177728', 'test', 0, 0, 0, ['Green', 0, "right"])
self.assertRaises(LocationNameNotUniqueError, self.interface.add_location,
'143072699567177728', 'test', 0, 0, 0, ['Green', 0, "right"])