Refactored session handling to avoid disconnect issues

doc_update
Joey Hines 2018-07-21 20:44:43 -05:00
parent 76df264060
commit 6e758de772
3 changed files with 195 additions and 175 deletions

View File

@ -26,111 +26,111 @@ class DatabaseInterface:
def __init__(self, db_engine_arg): def __init__(self, db_engine_arg):
self.database = GeoffreyDatabase(db_engine_arg) self.database = GeoffreyDatabase(db_engine_arg)
def add_location(self, owner, name, x_pos, z_pos, dimension=None): def add_location(self, session, owner, name, x_pos, z_pos, dimension=None):
location = Location(name, x_pos, z_pos, owner, dimension) location = Location(name, x_pos, z_pos, owner, dimension)
self.database.add_object(location) self.database.add_object(session, location)
return location return location
def add_shop(self, owner, name, x_pos, z_pos, dimension=None): def add_shop(self, session, owner, name, x_pos, z_pos, dimension=None):
shop = Shop(name, x_pos, z_pos, owner, dimension) shop = Shop(name, x_pos, z_pos, owner, dimension)
self.database.add_object(shop) self.database.add_object(session, shop)
return shop return shop
def add_tunnel(self, owner, color, number, location_name): def add_tunnel(self, session, owner, color, number, location_name):
if location_name is None: if location_name is None:
if len(self.find_tunnel_by_owner(owner)): if len(self.find_tunnel_by_owner(session, owner)):
raise EntryNameNotUniqueError raise EntryNameNotUniqueError
else: else:
location = None location = None
else: else:
try: try:
location = self.find_location_by_name_and_owner(owner, location_name)[0] location = self.find_location_by_name_and_owner(session, owner, location_name)[0]
except IndexError: except IndexError:
raise LocationLookUpError raise LocationLookUpError
tunnel = Tunnel(owner, color, number, location) tunnel = Tunnel(owner, color, number, location)
self.database.add_object(tunnel) self.database.add_object(session, tunnel)
return tunnel return tunnel
def add_item(self, owner, shop_name, item_name, price, amount): def add_item(self, session, owner, shop_name, item_name, price, amount):
try: try:
shop = self.find_shop_by_name_and_owner(owner, shop_name) shop = self.find_shop_by_name_and_owner(session, owner, shop_name)
item = ItemListing(item_name, price, amount, shop[0]) item = ItemListing(item_name, price, amount, shop[0])
self.database.add_object(item) self.database.add_object(session, item)
except IndexError: except IndexError:
raise LocationLookUpError raise LocationLookUpError
return item return item
def add_player(self, player_name, discord_id): def add_player(self, session, player_name, discord_id):
try: try:
player = self.find_player(player_name) player = self.find_player(session, player_name)
except PlayerNotFound: except PlayerNotFound:
try: try:
uuid = grab_UUID(player_name) uuid = grab_UUID(player_name)
player = self.find_player_by_mc_uuid(uuid) player = self.find_player_by_mc_uuid(session, uuid)
except PlayerNotFound: except PlayerNotFound:
player = Player(player_name) player = Player(player_name)
self.database.add_object(player, discord_id) self.database.add_object(session, player, discord_id)
finally: finally:
player.name = player_name player.name = player_name
self.database.session.commit() self.database.session.commit()
return player return player
def find_location_by_name(self, name): def find_location_by_name(self, session, name):
expr = Location.name.ilike('%{}%'.format(name)) expr = Location.name.ilike('%{}%'.format(name))
return self.database.query_by_filter(Location, expr) return self.database.query_by_filter(session, Location, expr)
def find_shop_by_name(self, name): def find_shop_by_name(self, session, name):
expr = Location.name.ilike('%{}%'.format(name)) expr = Location.name.ilike('%{}%'.format(name))
return self.database.query_by_filter(Shop, expr) return self.database.query_by_filter(session, Shop, expr)
def find_location_by_owner(self, owner): def find_location_by_owner(self, session, owner):
expr = Location.owner == owner expr = Location.owner == owner
return self.database.query_by_filter(Location, expr) return self.database.query_by_filter(session, Location, expr)
def find_shop_by_owner(self, owner): def find_shop_by_owner(self, session, owner):
expr = Shop.owner == owner expr = Shop.owner == owner
return self.database.query_by_filter(Shop, expr) return self.database.query_by_filter(session, Shop, expr)
def find_location_by_owner_name(self, owner_name): def find_location_by_owner_name(self, session, owner_name):
expr = Location.owner.has(Player.name.ilike(owner_name)) expr = Location.owner.has(Player.name.ilike(owner_name))
return self.database.query_by_filter(Location, expr) return self.database.query_by_filter(session, Location, expr)
def find_shop_by_name_and_owner(self, owner, name): def find_shop_by_name_and_owner(self, session, owner, name):
expr = (Shop.owner == owner) & (Shop.name.ilike(name)) expr = (Shop.owner == owner) & (Shop.name.ilike(name))
return self.database.query_by_filter(Shop, expr) return self.database.query_by_filter(session, Shop, expr)
def find_location_by_name_and_owner(self, owner, name): def find_location_by_name_and_owner(self, session, owner, name):
expr = (Location.owner == owner) & (Location.name.ilike(name)) expr = (Location.owner == owner) & (Location.name.ilike(name))
return self.database.query_by_filter(Location, expr) return self.database.query_by_filter(session, Location, expr)
def find_location_around(self, x_pos, z_pos, radius, dimension): def find_location_around(self, session, x_pos, z_pos, radius, dimension):
dimension_obj = Dimension.str_to_dimension(dimension) dimension_obj = Dimension.str_to_dimension(dimension)
expr = (Location.x < x_pos + radius + 1) & (Location.x > x_pos - radius - 1) & (Location.z < z_pos + radius + 1) \ expr = (Location.x < x_pos + radius + 1) & (Location.x > x_pos - radius - 1) & (Location.z < z_pos + radius + 1) \
& (Location.z > z_pos - radius - 1) & (Location.dimension == dimension_obj) & (Location.z > z_pos - radius - 1) & (Location.dimension == dimension_obj)
return self.database.query_by_filter(Location, expr) return self.database.query_by_filter(session, Location, expr)
def find_tunnel_by_owner(self, owner): def find_tunnel_by_owner(self, session, owner):
expr = Tunnel.owner == owner expr = Tunnel.owner == owner
return self.database.query_by_filter(Tunnel, expr) return self.database.query_by_filter(session, Tunnel, expr)
def find_tunnel_by_owner_name(self, owner_name): def find_tunnel_by_owner_name(self, session, owner_name):
expr = Tunnel.owner.has(Player.name.ilike(owner_name)) expr = Tunnel.owner.has(Player.name.ilike(owner_name))
return self.database.query_by_filter(Tunnel, expr) return self.database.query_by_filter(session, Tunnel, expr)
def find_item(self, item_name): def find_item(self, session, item_name):
expr = ItemListing.name.ilike('%{}%'.format(item_name)) expr = ItemListing.name.ilike('%{}%'.format(item_name))
return self.database.query_by_filter(ItemListing, expr) return self.database.query_by_filter(session, ItemListing, expr)
def find_shop_selling_item(self, item_name): def find_shop_selling_item(self, session, item_name):
listings = self.find_item(item_name) listings = self.find_item(session, item_name)
shops = [] shops = []
for listing in listings: for listing in listings:
@ -138,46 +138,46 @@ class DatabaseInterface:
return shops return shops
def find_player(self, player_name): def find_player(self, session, player_name):
expr = Player.name.ilike(player_name) expr = Player.name.ilike(player_name)
try: try:
player = self.database.query_by_filter(Player, expr)[0] player = self.database.query_by_filter(session, Player, expr)[0]
except IndexError: except IndexError:
raise PlayerNotFound raise PlayerNotFound
return player return player
def find_player_by_mc_uuid(self, uuid): def find_player_by_mc_uuid(self, session, uuid):
expr = Player.id == uuid expr = Player.id == uuid
try: try:
player = self.database.query_by_filter(Player, expr)[0] player = self.database.query_by_filter(session, Player, expr)[0]
except IndexError: except IndexError:
raise PlayerNotFound raise PlayerNotFound
return player return player
def find_player_by_discord_uuid(self, uuid): def find_player_by_discord_uuid(self, session, uuid):
expr = Player.discord_uuid == uuid expr = Player.discord_uuid == uuid
try: try:
player = self.database.query_by_filter(Player, expr)[0] player = self.database.query_by_filter(session, Player, expr)[0]
except IndexError: except IndexError:
raise PlayerNotFound raise PlayerNotFound
return player return player
def search_all_fields(self, search): def search_all_fields(self, session, search):
loc_string = '' loc_string = ''
count = 0 count = 0
expr = Location.owner.has(Player.name.ilike('%{}%'.format(search))) | Location.name.ilike('%{}%'.format(search)) expr = Location.owner.has(Player.name.ilike('%{}%'.format(search))) | Location.name.ilike('%{}%'.format(search))
for loc in self.database.query_by_filter(Location, expr): for loc in self.database.query_by_filter(session, Location, expr):
loc_string = "{}\n{}".format(loc_string, loc) loc_string = "{}\n{}".format(loc_string, loc)
count += 1 count += 1
expr = Tunnel.owner.has(Player.name.ilike('%{}%'.format(search))) & Tunnel.location is None expr = Tunnel.owner.has(Player.name.ilike('%{}%'.format(search))) & Tunnel.location is None
for loc in self.database.query_by_filter(Tunnel, expr): for loc in self.database.query_by_filter(session, Tunnel, expr):
loc_string = "{}\n{}".format(loc_string, loc) loc_string = "{}\n{}".format(loc_string, loc)
count += 1 count += 1
@ -186,113 +186,111 @@ class DatabaseInterface:
else: else:
return loc_string return loc_string
def delete_location(self, owner, name): def delete_location(self, session, owner, name):
expr = (Location.owner == owner) & (Location.name == name) expr = (Location.owner == owner) & (Location.name == name)
self.database.delete_entry(Location, expr) self.database.delete_entry(session, Location, expr)
class DiscordDatabaseInterface(DatabaseInterface): class DiscordDatabaseInterface(DatabaseInterface):
def add_location(self, owner_uuid, name, x_pos, z_pos, dimension=None):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid)
return DatabaseInterface.add_location(self, owner, name, x_pos, z_pos, dimension)
def add_shop(self, owner_uuid, name, x_pos, z_pos, dimension=None): def add_location(self, session, owner_uuid, name, x_pos, z_pos, dimension=None):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid) owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.add_shop(self, owner, name, x_pos, z_pos, dimension) return DatabaseInterface.add_location(self, session, owner, name, x_pos, z_pos, dimension)
def add_tunnel(self, owner_uuid, color, number, location_name=""): def add_shop(self, session, owner_uuid, name, x_pos, z_pos, dimension=None):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid) owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.add_tunnel(self, owner, color, number, location_name) return DatabaseInterface.add_shop(self, session, owner, name, x_pos, z_pos, dimension)
def add_item(self, owner_uuid, shop_name, item_name, price, amount): def add_tunnel(self, session, owner_uuid, color, number, location_name=""):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid) owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.add_item(self, owner, shop_name, item_name, price, amount) return DatabaseInterface.add_tunnel(self, session, owner, color, number, location_name)
def add_player(self, player_name, discord_id): def add_item(self, session, owner_uuid, shop_name, item_name, price, amount):
owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.add_item(self, session, owner, shop_name, item_name, price, amount)
def add_player(self, session, player_name, discord_id):
try: try:
player = self.find_player(player_name) player = self.find_player(session, player_name)
except PlayerNotFound: except PlayerNotFound:
try: try:
uuid = grab_UUID(player_name) uuid = grab_UUID(player_name)
player = self.find_player_by_mc_uuid(uuid) player = self.find_player_by_mc_uuid(session, uuid)
except PlayerNotFound: except PlayerNotFound:
player = Player(player_name, discord_id) player = Player(player_name, discord_id)
self.database.add_object(player) self.database.add_object(session, player)
finally: finally:
player.name = player_name player.name = player_name
return player return player
def find_location_by_owner_uuid(self, owner_uuid): def find_location_by_owner_uuid(self, session, owner_uuid):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid) owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.find_location_by_owner(self, owner) return DatabaseInterface.find_location_by_owner(self, session, owner)
def find_shop_by_owner_uuid(self, owner_uuid): def find_shop_by_owner_uuid(self, session, owner_uuid):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid) owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.find_shop_by_owner(self, owner) return DatabaseInterface.find_shop_by_owner(self, session, owner)
def find_shop_by_name_and_owner_uuid(self, owner_uuid, name): def find_shop_by_name_and_owner_uuid(self, session, owner_uuid, name):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid) owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.find_shop_by_name_and_owner(self, owner, name) return DatabaseInterface.find_shop_by_name_and_owner(self, session, owner, name)
def find_location_by_name_and_owner_uuid(self, owner_uuid, name): def find_location_by_name_and_owner_uuid(self, session, owner_uuid, name):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid) owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.find_location_by_name_and_owner(self, owner, name) return DatabaseInterface.find_location_by_name_and_owner(self, session, owner, name)
def delete_location(self, owner_uuid, name): def delete_location(self, session, owner_uuid, name):
owner = DatabaseInterface.find_player_by_discord_uuid(self, owner_uuid) owner = DatabaseInterface.find_player_by_discord_uuid(self, session, owner_uuid)
return DatabaseInterface.delete_location(self, owner, name) return DatabaseInterface.delete_location(self, session, owner, name)
class GeoffreyDatabase: class GeoffreyDatabase:
def __init__(self, engine_arg): def __init__(self, engine_arg):
self.engine = create_engine(engine_arg, echo=True, pool_recycle=3600, pool_pre_ping=True) self.engine = create_engine(engine_arg, echo=True, pool_recycle=3600, pool_pre_ping=True)
Session = sessionmaker(bind=self.engine) self.Session = sessionmaker(bind=self.engine)
self.session = Session()
self.meta = MetaData()
SQL_Base.metadata.create_all(self.engine) SQL_Base.metadata.create_all(self.engine)
def clear_all(self): def clear_all(self, session):
self.session.query(Tunnel).delete() session.query(Tunnel).delete()
self.session.query(ItemListing).delete() session.query(ItemListing).delete()
self.session.query(Shop).delete() session.query(Shop).delete()
self.session.query(Location).delete() session.query(Location).delete()
self.session.query(Player).delete() session.query(Player).delete()
self.session.commit() session.commit()
def add_object(self, obj): def add_object(self, session, obj):
try: try:
ret = not self.session.query(exists().where(type(obj).id == obj.id)) ret = not session.query(exists().where(type(obj).id == obj.id))
if not ret: if not ret:
self.session.add(obj) session.add(obj)
self.session.commit() session.commit()
except IntegrityError: except IntegrityError:
self.session.rollback() session.rollback()
raise EntryNameNotUniqueError raise EntryNameNotUniqueError
except DataError: except DataError:
self.session.rollback() session.rollback()
raise StringTooLong raise StringTooLong
def query_by_filter(self, session, obj_type, * args):
def query_by_filter(self, obj_type, * args):
filter_value = self.combine_filter(args) filter_value = self.combine_filter(args)
return self.session.query(obj_type).filter(filter_value).all() return session.query(obj_type).filter(filter_value).all()
def delete_entry(self, obj_type, * args): def delete_entry(self, session, obj_type, * args):
filter_value = self.combine_filter(args) filter_value = self.combine_filter(args)
entry = self.session.query(obj_type).filter(filter_value) entry = session.query(obj_type).filter(filter_value)
if entry.first() is not None: if entry.first() is not None:
entry.delete() entry.delete()
self.session.commit() session.commit()
else: else:
raise DeleteEntryError raise DeleteEntryError
self.session.close() session.close()
def print_database(self, obj_type): def print_database(self, session, obj_type):
obj_list = self.session.query(obj_type).all() obj_list = session.query(obj_type).all()
s = '' s = ''

View File

@ -77,15 +77,16 @@ async def register(ctx):
You must do this before adding entries to the database. You must do this before adding entries to the database.
''' '''
session = database_interface.database.Session()
try: try:
player_name = get_nickname(ctx.message.author) player_name = get_nickname(ctx.message.author)
database_interface.add_player(player_name, ctx.message.author.id) database_interface.add_player(session, player_name, ctx.message.author.id)
except AttributeError: except AttributeError:
await bot.say('{}, run this command on 24CC whoever you are'.format(ctx.message.author.mention)) await bot.say('{}, run this command on 24CC whoever you are'.format(ctx.message.author.mention))
return return
except LocationInitError: except LocationInitError:
raise commands.UserInputError raise commands.UserInputError
session.close()
await bot.say('{}, you have been added to the database.'.format(ctx.message.author.mention)) await bot.say('{}, you have been added to the database.'.format(ctx.message.author.mention))
@ -95,13 +96,13 @@ async def addbase(ctx, x_pos: int, z_pos: int, * args):
Adds your base to the database. The name is optional. Adds your base to the database. The name is optional.
?addbase [X Coordinate] [Y Coordinate] [Z Coordinate] [Base Name] ?addbase [X Coordinate] [Y Coordinate] [Z Coordinate] [Base Name]
''' '''
session = database_interface.database.Session()
if len(args) > 0: if len(args) > 0:
name = ' '.join(args) name = ' '.join(args)
else: else:
name = '{}\'s_Base'.format(database_interface.find_player_by_discord_uuid(ctx.message.author.id).name) name = '{}\'s_Base'.format(database_interface.find_player_by_discord_uuid(session, ctx.message.author.id).name)
try: try:
base = database_interface.add_location(ctx.message.author.id, name, x_pos, z_pos) base = database_interface.add_location(session, ctx.message.author.id, name, x_pos, z_pos)
except LocationInitError: except LocationInitError:
raise commands.UserInputError raise commands.UserInputError
except EntryNameNotUniqueError: except EntryNameNotUniqueError:
@ -112,6 +113,8 @@ async def addbase(ctx, x_pos: int, z_pos: int, * args):
await bot.say('{}, your base named **{}** located at {} has been added' await bot.say('{}, your base named **{}** located at {} has been added'
' to the database.'.format(ctx.message.author.mention, base.name, base.pos_to_str())) ' to the database.'.format(ctx.message.author.mention, base.name, base.pos_to_str()))
session.close()
@bot.command(pass_context=True) @bot.command(pass_context=True)
async def addshop(ctx, x_pos: int, z_pos: int, *args): async def addshop(ctx, x_pos: int, z_pos: int, *args):
@ -119,14 +122,14 @@ async def addshop(ctx, x_pos: int, z_pos: int, *args):
Adds your shop to the database. The name is optional. Adds your shop to the database. The name is optional.
?addshop [X Coordinate] [Y Coordinate] [Z Coordinate] [Shop Name] ?addshop [X Coordinate] [Y Coordinate] [Z Coordinate] [Shop Name]
''' '''
session = database_interface.database.Session()
if len(args) > 0: if len(args) > 0:
name = ' '.join(args) name = ' '.join(args)
else: else:
name = '{}\'s_Shop'.format(database_interface.find_player_by_discord_uuid(ctx.message.author.id).name) name = '{}\'s_Shop'.format(database_interface.find_player_by_discord_uuid(session, ctx.message.author.id).name)
try: try:
shop = database_interface.add_shop(ctx.message.author.id, name, x_pos, z_pos) shop = database_interface.add_shop(session, ctx.message.author.id, name, x_pos, z_pos)
except LocationInitError: except LocationInitError:
raise commands.UserInputError raise commands.UserInputError
except EntryNameNotUniqueError: except EntryNameNotUniqueError:
@ -137,6 +140,8 @@ async def addshop(ctx, x_pos: int, z_pos: int, *args):
await bot.say('{}, your shop named **{}** located at {} has been added' await bot.say('{}, your shop named **{}** located at {} has been added'
' to the database.'.format(ctx.message.author.mention, shop.name, shop.pos_to_str())) ' to the database.'.format(ctx.message.author.mention, shop.name, shop.pos_to_str()))
session.close()
@bot.command(pass_context=True) @bot.command(pass_context=True)
async def tunnel(ctx, tunnel_color: str, tunnel_number: int, *args): async def tunnel(ctx, tunnel_color: str, tunnel_number: int, *args):
@ -145,14 +150,14 @@ async def tunnel(ctx, tunnel_color: str, tunnel_number: int, *args):
The location name is optional. If the location has a tunnel, it is updated. The location name is optional. If the location has a tunnel, it is updated.
?addtunnel [Tunnel Color] [Tunnel_Number] [Location Name] ?addtunnel [Tunnel Color] [Tunnel_Number] [Location Name]
''' '''
session = database_interface.database.Session()
try: try:
if len(args) == 0: if len(args) == 0:
location_name = None location_name = None
else: else:
location_name = name = ' '.join(args) location_name = ' '.join(args)
database_interface.add_tunnel(ctx.message.author.id, tunnel_color, tunnel_number, location_name) database_interface.add_tunnel(session, ctx.message.author.id, tunnel_color, tunnel_number, location_name)
except EntryNameNotUniqueError: except EntryNameNotUniqueError:
await bot.say('{}, you already have one tunnel in the database, please specify a location.'.format( await bot.say('{}, you already have one tunnel in the database, please specify a location.'.format(
ctx.message.author.mention)) ctx.message.author.mention))
@ -167,6 +172,8 @@ async def tunnel(ctx, tunnel_color: str, tunnel_number: int, *args):
await bot.say('{}, your tunnel has been added to the database'.format(ctx.message.author.mention)) await bot.say('{}, your tunnel has been added to the database'.format(ctx.message.author.mention))
session.close()
@bot.command(pass_context=True) @bot.command(pass_context=True)
async def find(ctx, search: str): async def find(ctx, search: str):
@ -174,14 +181,16 @@ async def find(ctx, search: str):
Finds all the locations and tunnels matching the search term Finds all the locations and tunnels matching the search term
?find [Search] ?find [Search]
''' '''
session = database_interface.database.Session()
try: try:
result = database_interface.search_all_fields(search) result = database_interface.search_all_fields(session, search)
await bot.say('{}, The following entries match **{}**:\n{}'.format(ctx.message.author.mention, search, result)) await bot.say('{}, The following entries match **{}**:\n{}'.format(ctx.message.author.mention, search, result))
except LocationLookUpError: except LocationLookUpError:
await bot.say('{}, no matches **{}** were found in the database'.format(ctx.message.author.mention, search)) await bot.say('{}, no matches **{}** were found in the database'.format(ctx.message.author.mention, search))
session.close()
@bot.command(pass_context=True) @bot.command(pass_context=True)
async def delete(ctx, * args): async def delete(ctx, * args):
@ -189,14 +198,16 @@ async def delete(ctx, * args):
Deletes a location from the database. Deletes a location from the database.
?delete [Location name] ?delete [Location name]
''' '''
session = database_interface.database.Session()
try: try:
name = ' '.join(args) name = ' '.join(args)
database_interface.delete_location(ctx.message.author.id, name) database_interface.delete_location(session, ctx.message.author.id, name)
await bot.say('{}, your location named **{}** has been deleted.'.format(ctx.message.author.mention, name)) await bot.say('{}, your location named **{}** has been deleted.'.format(ctx.message.author.mention, name))
except (DeleteEntryError, PlayerNotFound): except (DeleteEntryError, PlayerNotFound):
await bot.say('{}, you do not have a location named **{}**.'.format(ctx.message.author.mention, name)) await bot.say('{}, you do not have a location named **{}**.'.format(ctx.message.author.mention, name))
session.close()
@bot.command(pass_context=True) @bot.command(pass_context=True)
async def findaround(ctx, x_pos: int, z_pos: int, * args): async def findaround(ctx, x_pos: int, z_pos: int, * args):
@ -210,7 +221,7 @@ async def findaround(ctx, x_pos: int, z_pos: int, * args):
Optional Flags: Optional Flags:
-d [dimension] -d [dimension]
''' '''
session = database_interface.database.Session()
radius = 200 radius = 200
dimension = 'Overworld' dimension = 'Overworld'
@ -225,7 +236,7 @@ async def findaround(ctx, x_pos: int, z_pos: int, * args):
if args[1] == '-d': if args[1] == '-d':
dimension = args[2] dimension = args[2]
base_list = database_interface.find_location_around(x_pos, z_pos, radius, dimension) base_list = database_interface.find_location_around(session, x_pos, z_pos, radius, dimension)
if len(base_list) != 0: if len(base_list) != 0:
base_string = loc_list_to_string(base_list, '{} \n{}') base_string = loc_list_to_string(base_list, '{} \n{}')
@ -238,6 +249,8 @@ async def findaround(ctx, x_pos: int, z_pos: int, * args):
except ValueError: except ValueError:
raise commands.UserInputError raise commands.UserInputError
session.close()
@bot.command(pass_context=True) @bot.command(pass_context=True)
async def additem(ctx, item_name: str, quantity: int, diamond_price: int, * args): async def additem(ctx, item_name: str, quantity: int, diamond_price: int, * args):
@ -247,9 +260,9 @@ async def additem(ctx, item_name: str, quantity: int, diamond_price: int, * args
?additem [Item Name] [Quantity] [Price] [Shop name] ?additem [Item Name] [Quantity] [Price] [Shop name]
''' '''
session = database_interface.database.Session()
try: try:
shop_list = database_interface.find_shop_by_owner_uuid(ctx.message.author.id) shop_list = database_interface.find_shop_by_owner_uuid(session, ctx.message.author.id)
if len(shop_list) == 1: if len(shop_list) == 1:
shop_name = shop_list[0].name shop_name = shop_list[0].name
@ -259,7 +272,7 @@ async def additem(ctx, item_name: str, quantity: int, diamond_price: int, * args
else: else:
shop_name = ' '.join(args) shop_name = ' '.join(args)
database_interface.add_item(ctx.message.author.id, shop_name, item_name, diamond_price, quantity) database_interface.add_item(session, ctx.message.author.id, shop_name, item_name, diamond_price, quantity)
await bot.say('{}, **{}** has been added to the inventory of **{}**.'.format(ctx.message.author.mention, await bot.say('{}, **{}** has been added to the inventory of **{}**.'.format(ctx.message.author.mention,
item_name, shop_name)) item_name, shop_name))
except PlayerNotFound: except PlayerNotFound:
@ -271,6 +284,8 @@ async def additem(ctx, item_name: str, quantity: int, diamond_price: int, * args
await bot.say('{}, you don\'t have any shops named **{}** in the database.'.format(ctx.message.author.mention, await bot.say('{}, you don\'t have any shops named **{}** in the database.'.format(ctx.message.author.mention,
shop_name)) shop_name))
session.close()
@bot.command(pass_context=True) @bot.command(pass_context=True)
async def selling(ctx, item_name: str): async def selling(ctx, item_name: str):
@ -279,12 +294,15 @@ async def selling(ctx, item_name: str):
?selling [item] ?selling [item]
''' '''
shop_list = database_interface.find_shop_selling_item(item_name) session = database_interface.database.Session()
shop_list = database_interface.find_shop_selling_item(session, item_name)
shop_list_str = loc_list_to_string(shop_list) shop_list_str = loc_list_to_string(shop_list)
await bot.say('{}, the following shops sell **{}**: \n{}'.format(ctx.message.author.mention, item_name, await bot.say('{}, the following shops sell **{}**: \n{}'.format(ctx.message.author.mention, item_name,
shop_list_str)) shop_list_str))
session.close()
@bot.command(pass_context=True) @bot.command(pass_context=True)
async def info(ctx, * args): async def info(ctx, * args):
@ -295,15 +313,18 @@ async def info(ctx, * args):
?info [Location Name] ?info [Location Name]
''' '''
session = database_interface.database.Session()
try: try:
name = ' '.join(args) name = ' '.join(args)
loc = database_interface.find_location_by_name(name)[0] loc = database_interface.find_location_by_name(session, name)[0]
except IndexError: except IndexError:
await bot.say('{}, no locations in the database match {}.'.format(ctx.message.author.mention, name)) await bot.say('{}, no locations in the database match {}.'.format(ctx.message.author.mention, name))
return return
await bot.say('{}'.format(loc.full_str())) await bot.say('{}'.format(loc.full_str()))
session.close()
# Helper Functions ************************************************************ # Helper Functions ************************************************************
def get_nickname(discord_user): def get_nickname(discord_user):

View File

@ -9,48 +9,49 @@ class TestGeoffreyDatabase(TestCase):
config = read_config() config = read_config()
engine_arg = config['SQL']['test_args'] engine_arg = config['SQL']['test_args']
self.interface = DiscordDatabaseInterface(engine_arg) self.interface = DiscordDatabaseInterface(engine_arg)
self.session = self.interface.database.Session()
self.owner = Player('ZeroHD', '143072699567177728') self.owner = Player('ZeroHD', '143072699567177728')
self.loc = Location('test', 1, 3, self.owner, dimension='Nether') self.loc = Location('test', 1, 3, self.owner, dimension='Nether')
self.tunnel = Tunnel(self.owner, 'Green', 105, self.loc) self.tunnel = Tunnel(self.owner, 'Green', 105, self.loc)
def tearDown(self): def tearDown(self):
self.interface.database.clear_all() self.interface.database.clear_all(self.session)
self.session.close()
def test_add_object(self): def test_add_object(self):
self.interface.database.add_object(self.loc) self.interface.database.add_object(self.session, self.loc)
self.interface.database.add_object(self.owner) self.interface.database.add_object(self.session, self.owner)
self.interface.database.add_object(self.tunnel) self.interface.database.add_object(self.session, self.tunnel)
uuid = grab_UUID('ZeroHD') uuid = grab_UUID('ZeroHD')
expr = Player.mc_uuid == uuid expr = Player.mc_uuid == uuid
p = self.interface.database.query_by_filter(Player, expr)[0] p = self.interface.database.query_by_filter(self.session, Player, expr)[0]
expr = Location.owner == p expr = Location.owner == p
loc2 = self.interface.database.query_by_filter(Location, expr)[0] loc2 = self.interface.database.query_by_filter(self.session, Location, expr)[0]
self.assertEqual(self.loc.id, loc2.id) self.assertEqual(self.loc.id, loc2.id)
def test_query_by_filter(self): def test_query_by_filter(self):
self.interface.database.add_object(self.loc) self.interface.database.add_object(self.session, self.loc)
self.interface.database.add_object(self.owner) self.interface.database.add_object(self.session, self.owner)
expr = (Location.owner == self.owner) expr = (Location.owner == self.owner)
loc2 = self.interface.database.query_by_filter(Location, expr)[0] loc2 = self.interface.database.query_by_filter(self.session, Location, expr)[0]
self.assertEqual(loc2.id, self.loc.id) self.assertEqual(loc2.id, self.loc.id)
def test_delete_entry(self): def test_delete_entry(self):
self.interface.database.add_object(self.loc) self.interface.database.add_object(self.session, self.loc)
self.interface.database.add_object(self.owner) self.interface.database.add_object(self.session, self.owner)
expr = Location.owner == self.owner expr = Location.owner == self.owner
self.interface.database.delete_entry(Location, expr) self.interface.database.delete_entry(self.session, Location, expr)
loc2 = self.interface.database.query_by_filter(Location, expr) loc2 = self.interface.database.query_by_filter(self.session, Location, expr)
self.assertEqual(len(loc2), 0) self.assertEqual(len(loc2), 0)
self.assertRaises(DeleteEntryError, self.interface.database.delete_entry, Location, expr) self.assertRaises(DeleteEntryError, self.interface.database.delete_entry, self.session, Location, expr)
def test_add_shop(self): def test_add_shop(self):
owner = self.add_player() owner = self.add_player()
@ -58,47 +59,47 @@ class TestGeoffreyDatabase(TestCase):
self.assertEqual(type(shop), Shop) self.assertEqual(type(shop), Shop)
shop_list = self.interface.find_shop_by_name('test') shop_list = self.interface.find_shop_by_name(self.session, 'test')
self.assertEqual(shop_list[0].dimension, shop.dimension) self.assertEqual(shop_list[0].dimension, shop.dimension)
def add_shop(self): def add_shop(self):
return self.interface.add_shop('143072699567177728', 'test', 1, 3, "nether") return self.interface.add_shop(self.session, '143072699567177728', 'test', 1, 3, "nether")
def add_player(self): def add_player(self):
return self.interface.add_player('ZeroHD', '143072699567177728') return self.interface.add_player(self.session, 'ZeroHD', '143072699567177728')
def add_loc(self): def add_loc(self):
return self.interface.add_location('143072699567177728', 'test', 0, 0) return self.interface.add_location(self.session, '143072699567177728', 'test', 0, 0)
def test_add_two_shops(self): def test_add_two_shops(self):
owner = self.add_player() owner = self.add_player()
shop1 = self.add_shop() shop1 = self.add_shop()
shop2 = self.interface.add_shop('143072699567177728', 'no u', 1, 3) shop2 = self.interface.add_shop(self.session, '143072699567177728', 'no u', 1, 3)
loc_list = self.interface.find_location_by_owner_uuid('143072699567177728') loc_list = self.interface.find_location_by_owner_uuid(self.session, '143072699567177728')
self.assertEqual(loc_list[1].id, shop2.id) self.assertEqual(loc_list[1].id, shop2.id)
def test_add_tunnel(self): def test_add_tunnel(self):
self.add_player() self.add_player()
tunnel1 = self.interface.add_tunnel('143072699567177728', 'green', 155, None) tunnel1 = self.interface.add_tunnel(self.session, '143072699567177728', 'green', 155, None)
tunnel2 = self.interface.find_tunnel_by_owner_name('ZeroHD')[0] tunnel2 = self.interface.find_tunnel_by_owner_name(self.session, 'ZeroHD')[0]
self.assertEqual(tunnel1, tunnel2) self.assertEqual(tunnel1, tunnel2)
def test_add_item(self): def test_add_item(self):
owner = self.add_player() owner = self.add_player()
self.add_shop() self.add_shop()
self.interface.add_item('143072699567177728', 'test', 'dirt', 1, 15) self.interface.add_item(self.session, '143072699567177728', 'test', 'dirt', 1, 15)
shops = self.interface.find_shop_selling_item('dirt') shops = self.interface.find_shop_selling_item(self.session, 'dirt')
self.assertGreater(len(shops), 0) self.assertGreater(len(shops), 0)
def test_find_location_by_owner(self): def test_find_location_by_owner(self):
owner = self.add_player() owner = self.add_player()
shop = self.add_shop() shop = self.add_shop()
loc_list = self.interface.find_location_by_owner(owner) loc_list = self.interface.find_location_by_owner(self.session, owner)
self.assertEqual(loc_list[0].id, shop.id) self.assertEqual(loc_list[0].id, shop.id)
@ -106,7 +107,7 @@ class TestGeoffreyDatabase(TestCase):
owner = self.add_player() owner = self.add_player()
shop = self.add_shop() shop = self.add_shop()
loc_list = self.interface.find_location_by_name_and_owner_uuid('143072699567177728', 'test') loc_list = self.interface.find_location_by_name_and_owner_uuid(self.session, '143072699567177728', 'test')
self.assertEqual(loc_list[0].id, shop.id) self.assertEqual(loc_list[0].id, shop.id)
@ -114,9 +115,9 @@ class TestGeoffreyDatabase(TestCase):
owner = self.add_player() owner = self.add_player()
self.add_loc() self.add_loc()
self.interface.delete_location('143072699567177728', 'test') self.interface.delete_location(self.session, '143072699567177728', 'test')
loc_list = self.interface.find_location_by_name_and_owner_uuid('143072699567177728', 'test') loc_list = self.interface.find_location_by_name_and_owner_uuid(self.session, '143072699567177728', 'test')
self.assertEqual(len(loc_list), 0) self.assertEqual(len(loc_list), 0)
@ -126,27 +127,27 @@ class TestGeoffreyDatabase(TestCase):
dim = "o" dim = "o"
loc_list = self.interface.find_location_around(100, 100, 100, dim) loc_list = self.interface.find_location_around(self.session, 100, 100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(200, 200, 100, dim) loc_list = self.interface.find_location_around(self.session, 200, 200, 100, dim)
self.assertEqual(len(loc_list), 0) self.assertEqual(len(loc_list), 0)
loc_list = self.interface.find_location_around(-100, -100, 100, dim) loc_list = self.interface.find_location_around(self.session, -100, -100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(100, -100, 100, dim) loc_list = self.interface.find_location_around(self.session, 100, -100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(-100, 100, 100, dim) loc_list = self.interface.find_location_around(self.session, -100, 100, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
loc_list = self.interface.find_location_around(50, -50, 100, dim) loc_list = self.interface.find_location_around(self.session, 50, -50, 100, dim)
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
@ -154,7 +155,7 @@ class TestGeoffreyDatabase(TestCase):
owner = self.add_player() owner = self.add_player()
loc = self.add_loc() loc = self.add_loc()
loc_list = self.interface.find_location_by_name('test') loc_list = self.interface.find_location_by_name(self.session, 'test')
self.assertEqual(loc_list[0].name, loc.name) self.assertEqual(loc_list[0].name, loc.name)
@ -162,7 +163,7 @@ class TestGeoffreyDatabase(TestCase):
owner = self.add_player() owner = self.add_player()
loc = self.add_loc() loc = self.add_loc()
loc_list = self.interface.search_all_fields('ZeroHD') loc_list = self.interface.search_all_fields(self.session, 'ZeroHD')
self.assertEqual(type(loc_list), str) self.assertEqual(type(loc_list), str)
@ -170,15 +171,15 @@ class TestGeoffreyDatabase(TestCase):
owner = self.add_player() owner = self.add_player()
loc = self.add_loc() loc = self.add_loc()
loc_list = self.interface.find_location_by_owner_name('zerohd') loc_list = self.interface.find_location_by_owner_name(self.session, 'zerohd')
self.assertEqual(loc_list[0].id, loc.id) self.assertEqual(loc_list[0].id, loc.id)
self.interface.add_shop('143072699567177728', 'testshop', 1, 3, 'neThEr') self.interface.add_shop(self.session, '143072699567177728', 'testshop', 1, 3, 'neThEr')
self.interface.add_item('143072699567177728', 'testshop', 'dirts', 1, 15) self.interface.add_item(self.session, '143072699567177728', 'testshop', 'dirts', 1, 15)
shops = self.interface.find_shop_selling_item('diRt') shops = self.interface.find_shop_selling_item(self.session, 'diRt')
self.assertGreater(len(shops), 0) self.assertGreater(len(shops), 0)
@ -186,14 +187,14 @@ class TestGeoffreyDatabase(TestCase):
#self.assertEqual(shops[0].name, 'testshop') #self.assertEqual(shops[0].name, 'testshop')
loc_list = self.interface.find_location_by_name('TEST') loc_list = self.interface.find_location_by_name(self.session, 'TEST')
self.assertEqual(loc_list[0].name, 'test') self.assertEqual(loc_list[0].name, 'test')
def test_big_input(self): def test_big_input(self):
owner = self.add_player() owner = self.add_player()
self.assertRaises(StringTooLong, self.interface.add_location, '143072699567177728', self.assertRaises(StringTooLong, self.interface.add_location, self.session,'143072699567177728',
'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT' 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT'
'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT' 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT'
'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT', 0, 0,) 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT', 0, 0,)
@ -202,18 +203,18 @@ class TestGeoffreyDatabase(TestCase):
self.add_player() self.add_player()
self.add_loc() self.add_loc()
self.assertRaises(EntryNameNotUniqueError, self.interface.add_location, self.assertRaises(EntryNameNotUniqueError, self.interface.add_location, self.session,
'143072699567177728', 'test', 0, 0, 0) '143072699567177728', 'test', 0, 0, 0)
def test_delete_parent(self): def test_delete_parent(self):
owner = self.add_player() owner = self.add_player()
loc = self.add_shop() loc = self.add_shop()
self.interface.add_item('143072699567177728', 'test', 'dirt', 1, 15) self.interface.add_item(self.session, '143072699567177728', 'test', 'dirt', 1, 15)
self.interface.delete_location('143072699567177728', 'test') self.interface.delete_location(self.session, '143072699567177728', 'test')
shops = self.interface.find_shop_selling_item('dirt') shops = self.interface.find_shop_selling_item(self.session, 'dirt')
self.assertEqual(len(shops), 0) self.assertEqual(len(shops), 0)