Refactored DatabaseInterface to move all location addition and searching into one set of functions. The type of location is then specified as a parameter.

doc_update
Joey Hines 2018-08-30 12:27:05 -05:00
parent 2111bf8c8e
commit 0d26818fa6
4 changed files with 46 additions and 54 deletions

View File

@ -31,7 +31,9 @@ def create_config(config, path):
def read_config(path): def read_config(path):
config = configparser.ConfigParser() config = configparser.ConfigParser()
try: try:
config.read_file(codecs.open(path, "r", "utf8")) file = codecs.open(path, "r", "utf8")
config.read_file(file)
file.close()
except FileNotFoundError: except FileNotFoundError:
create_config(config, path) create_config(config, path)
print("Config generated.") print("Config generated.")

View File

@ -40,13 +40,13 @@ class Commands:
try: try:
player = self.get_player(session, discord_uuid, mc_uuid) player = self.get_player(session, discord_uuid, mc_uuid)
if len(self.interface.find_location_by_owner(session, player)) == 0: if len(self.interface.find_location_by_owner(session, player, loc_type=Base)) == 0:
if base_name is None: if base_name is None:
base_name = "{}'s Base".format(player.name) base_name = "{}'s Base".format(player.name)
elif base_name is None: elif base_name is None:
raise EntryNameNotUniqueError raise EntryNameNotUniqueError
base = self.interface.add_base(session, player, base_name, x_pos, z_pos) base = self.interface.add_loc(session, player, base_name, x_pos, z_pos, loc_type=Base)
base_str = base.__str__() base_str = base.__str__()
finally: finally:
@ -60,13 +60,13 @@ class Commands:
try: try:
player = self.get_player(session, discord_uuid, mc_uuid) player = self.get_player(session, discord_uuid, mc_uuid)
if len(self.interface.find_shop_by_owner(session, player)) == 0: if len(self.interface.find_location_by_owner(session, player, loc_type=Shop)) == 0:
if shop_name is None: if shop_name is None:
shop_name = "{}'s Shop".format(player.name) shop_name = "{}'s Shop".format(player.name)
elif shop_name is None: elif shop_name is None:
raise EntryNameNotUniqueError raise EntryNameNotUniqueError
shop = self.interface.add_shop(session, player, shop_name, x_pos, z_pos) shop = self.interface.add_loc(session, player, shop_name, x_pos, z_pos, loc_type=Shop)
shop_name = shop.__str__() shop_name = shop.__str__()
finally: finally:
@ -129,7 +129,7 @@ class Commands:
session = self.interface.database.Session() session = self.interface.database.Session()
try: try:
player = self.get_player(session, discord_uuid, mc_uuid) player = self.get_player(session, discord_uuid, mc_uuid)
shop_list = self.interface.find_shop_by_owner(session, player) shop_list = self.interface.find_location_by_owner(session, player, loc_type=Shop)
if shop_name is None: if shop_name is None:
if len(shop_list) == 1: if len(shop_list) == 1:
@ -291,7 +291,7 @@ class Commands:
player = self.get_player(session, discord_uuid=discord_uuid, mc_uuid=mc_uuid) player = self.get_player(session, discord_uuid=discord_uuid, mc_uuid=mc_uuid)
if shop_name is None: if shop_name is None:
shop_list = self.interface.find_shop_by_owner(session, player) shop_list = self.interface.find_location_by_owner(session, player, loc_type=Shop)
if len(shop_list) == 0: if len(shop_list) == 0:
raise LocationLookUpError raise LocationLookUpError
@ -301,7 +301,7 @@ class Commands:
shop = shop_list[0] shop = shop_list[0]
else: else:
shop = self.interface.find_shop_by_name_and_owner(session, player, shop_name)[0] shop = self.interface.find_location_by_name_and_owner(session, player, shop_name, loc_type=Shop)[0]
expr = (ItemListing.name == item) & (ItemListing.shop == shop) expr = (ItemListing.name == item) & (ItemListing.shop == shop)
self.interface.database.delete_entry(session, ItemListing, expr) self.interface.database.delete_entry(session, ItemListing, expr)

View File

@ -6,15 +6,16 @@ class DatabaseInterface:
def __init__(self, bot_config, debug=False): def __init__(self, bot_config, debug=False):
self.database = GeoffreyDatabase(bot_config, debug) self.database = GeoffreyDatabase(bot_config, debug)
def add_base(self, session, owner, name, x_pos, z_pos, dimension=None): def add_loc(self, session, owner, name, x_pos, z_pos, dimension=None, loc_type=Location):
base = Base(name, x_pos, z_pos, owner, dimension) if loc_type == Base:
self.database.add_object(session, base) loc = Base(name, x_pos, z_pos, owner, dimension)
return base elif loc_type == Shop:
loc = Shop(name, x_pos, z_pos, owner, dimension)
else:
loc = Location(name, x_pos, z_pos, owner, dimension)
def add_shop(self, session, owner, name, x_pos, z_pos, dimension=None): self.database.add_object(session, loc)
shop = Shop(name, x_pos, z_pos, owner, dimension) return loc
self.database.add_object(session, shop)
return shop
def add_tunnel(self, session, owner, color, number, location_name): def add_tunnel(self, session, owner, color, number, location_name):
tunnels = self.find_tunnel_by_owner(session, owner) tunnels = self.find_tunnel_by_owner(session, owner)
@ -40,7 +41,7 @@ class DatabaseInterface:
def add_item(self, session, owner, shop_name, item_name, price, amount): def add_item(self, session, owner, shop_name, item_name, price, amount):
try: try:
shop = self.find_shop_by_name_and_owner(session, owner, shop_name) shop = self.find_location_by_name_and_owner(session, owner, shop_name, loc_type=Shop)
item = ItemListing(item_name, price, amount, shop[0]) item = ItemListing(item_name, price, amount, shop[0])
self.database.add_object(session, item) self.database.add_object(session, item)
@ -63,39 +64,27 @@ class DatabaseInterface:
return player return player
def find_location_by_name(self, session, name): def find_location_by_name(self, session, name, loc_type=Location):
expr = Location.name.ilike('%{}%'.format(name)) expr = loc_type.name.ilike('%{}%'.format(name))
return self.database.query_by_filter(session, Location, expr) return self.database.query_by_filter(session, loc_type, expr)
def find_shop_by_name(self, session, name): def find_location_by_owner(self, session, owner, loc_type=Location):
expr = Location.name.ilike('%{}%'.format(name)) expr = loc_type.owner == owner
return self.database.query_by_filter(session, Shop, expr) return self.database.query_by_filter(session, loc_type, expr)
def find_location_by_owner(self, session, owner): def find_location_by_owner_name(self, session, owner_name, loc_type=Location):
expr = Location.owner == owner expr = loc_type.owner.has(Player.name.ilike(owner_name))
return self.database.query_by_filter(session, Location, expr) return self.database.query_by_filter(session, loc_type, expr)
def find_shop_by_owner(self, session, owner): def find_location_by_name_and_owner(self, session, owner, name, loc_type=Location):
expr = Shop.owner == owner expr = (loc_type.owner == owner) & (loc_type.name.ilike(name))
return self.database.query_by_filter(session, Shop, expr) return self.database.query_by_filter(session, loc_type, expr)
def find_location_by_owner_name(self, session, owner_name): def find_location_around(self, session, x_pos, z_pos, radius, dimension, loc_type=Location):
expr = Location.owner.has(Player.name.ilike(owner_name))
return self.database.query_by_filter(session, Location, expr)
def find_shop_by_name_and_owner(self, session, owner, name):
expr = (Shop.owner == owner) & (Shop.name.ilike(name))
return self.database.query_by_filter(session, Shop, expr)
def find_location_by_name_and_owner(self, session, owner, name):
expr = (Location.owner == owner) & (Location.name.ilike(name))
return self.database.query_by_filter(session, Location, expr)
def find_location_around(self, session, x_pos, z_pos, radius, dimension):
dimension_obj = Dimension.str_to_dimension(dimension) dimension_obj = Dimension.str_to_dimension(dimension)
expr = (Location.x < x_pos + radius + 1) & (Location.x > x_pos - radius - 1) & \ expr = (loc_type.x < x_pos + radius + 1) & (loc_type.x > x_pos - radius - 1) & \
(Location.z < z_pos + radius + 1) \ (loc_type.z < z_pos + radius + 1) \
& (Location.z > z_pos - radius - 1) & (Location.dimension == dimension_obj) & (loc_type.z > z_pos - radius - 1) & (loc_type.dimension == dimension_obj)
return list_to_string(self.database.query_by_filter(session, Location, expr)) return list_to_string(self.database.query_by_filter(session, Location, expr))

View File

@ -4,6 +4,7 @@ from unittest import TestCase
from DatabaseInterface import * from DatabaseInterface import *
from BotConfig import * from BotConfig import *
class TestGeoffreyDatabase(TestCase): class TestGeoffreyDatabase(TestCase):
def setUp(self): def setUp(self):
path = os.path.dirname(os.path.abspath(__file__)) path = os.path.dirname(os.path.abspath(__file__))
@ -21,7 +22,7 @@ class TestGeoffreyDatabase(TestCase):
self.session.close() self.session.close()
def add_shop(self, player): def add_shop(self, player):
shop = self.interface.add_shop(self.session, player, 'test', 1, 3, "nether") shop = self.interface.add_loc(self.session, player, 'test', 1, 3, "nether", loc_type=Shop)
return shop return shop
def add_player(self): def add_player(self):
@ -29,7 +30,7 @@ class TestGeoffreyDatabase(TestCase):
return player return player
def add_loc(self, player): def add_loc(self, player):
loc = self.interface.add_base(self.session, player, 'test', 0, 0) loc = self.interface.add_loc(self.session, player, 'test', 0, 0, loc_type=Base)
return loc return loc
def test_add_object(self): def test_add_object(self):
@ -79,13 +80,13 @@ class TestGeoffreyDatabase(TestCase):
self.assertEqual(type(shop), Shop) self.assertEqual(type(shop), Shop)
shop_list = self.interface.find_shop_by_name(self.session, 'test') shop_list = self.interface.find_location_by_name(self.session, 'test', loc_type=Shop)
self.assertEqual(shop_list[0].dimension, shop.dimension) self.assertEqual(shop_list[0].dimension, shop.dimension)
def test_add_two_shops(self): def test_add_two_shops(self):
owner = self.add_player() owner = self.add_player()
self.add_shop(owner) self.add_shop(owner)
shop2 = self.interface.add_shop(self.session, owner, 'no u', 1, 3) shop2 = self.interface.add_loc(self.session, owner, 'no u', 1, 3, loc_type=Shop)
loc_list = self.interface.find_location_by_owner(self.session, owner) loc_list = self.interface.find_location_by_owner(self.session, owner)
@ -186,7 +187,7 @@ class TestGeoffreyDatabase(TestCase):
self.assertEqual(loc_list[0].id, loc.id) self.assertEqual(loc_list[0].id, loc.id)
self.interface.add_shop(self.session, owner, 'testshop', 1, 3, 'neThEr') self.interface.add_loc(self.session, owner, 'testshop', 1, 3, 'neThEr', loc_type=Shop)
self.interface.add_item(self.session, owner, 'testshop', 'dirts', 1, 15) self.interface.add_item(self.session, owner, 'testshop', 'dirts', 1, 15)
@ -201,17 +202,17 @@ class TestGeoffreyDatabase(TestCase):
def test_big_input(self): def test_big_input(self):
owner = self.add_player() owner = self.add_player()
self.assertRaises(DatabaseValueError, self.interface.add_base, self.session, owner, self.assertRaises(DatabaseValueError, self.interface.add_loc, self.session, owner,
'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT' 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT'
'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT' 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT'
'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT', 0, 0, ) 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT', 0, 0, Shop)
def test_duplicate_name(self): def test_duplicate_name(self):
owner = self.add_player() owner = self.add_player()
self.add_loc(owner) self.add_loc(owner)
self.assertRaises(EntryNameNotUniqueError, self.interface.add_base, self.session, self.assertRaises(EntryNameNotUniqueError, self.interface.add_loc, self.session,
owner, 'test', 0, 0, 0) owner, 'test', 0, 0, Shop)
def test_delete_parent(self): def test_delete_parent(self):
owner = self.add_player() owner = self.add_player()