diff --git a/BotConfig.py b/BotConfig.py new file mode 100644 index 0000000..2f5f2ed --- /dev/null +++ b/BotConfig.py @@ -0,0 +1,34 @@ +import configparser + +def read_config(): + config = configparser.ConfigParser() + config.read('GeoffreyConfig.ini') + + if len(config.sections()) == 0: + create_config(config) + print("GeoffreyConfig.ini generated.") + quit(0) + + return config + + +def create_config(config): + config['Discord'] = {'Token': ''} + config['SQL'] = {'Dialect+Driver': '', 'username': '', 'password':'', 'host': '', 'port': '', 'database':'', + 'test_args':''} + + with open('GeoffreyConfig.ini', 'w') as configfile: + config.write(configfile) + + +def get_engine_arg(config): + driver = config['SQL']['Dialect+Driver'] + username = config['SQL']['username'] + password = config['SQL']['password'] + host = config['SQL']['host'] + port = config['SQL']['port'] + database_name = config['SQL']['database'] + + engine_args = '{}://{}:{}@{}:{}/{}' + + return engine_args.format(driver, username, password, host, port, database_name) \ No newline at end of file diff --git a/BotErrors.py b/BotErrors.py index 730b226..e23661e 100644 --- a/BotErrors.py +++ b/BotErrors.py @@ -24,4 +24,8 @@ class PlayerNotFound(DataBaseError): class EntryNameNotUniqueError(DataBaseError): '''A location by that name is already in the database.''' +class StringTooLong(DataBaseError): + '''Given string is too long.''' + + diff --git a/DatabaseModels.py b/DatabaseModels.py index 5e8f1f7..3d0e70e 100644 --- a/DatabaseModels.py +++ b/DatabaseModels.py @@ -1,13 +1,12 @@ -from sqlalchemy import Column, Integer, String, ForeignKey, Enum -import enum -from sqlalchemy.ext.declarative import declarative_base -from BotErrors import * -from sqlalchemy import create_engine, exists, literal +from sqlalchemy import Column, Integer, String, ForeignKey, Enum, create_engine, exists, MetaData from sqlalchemy.orm import sessionmaker, relationship, column_property -from sqlalchemy.exc import IntegrityError -import sqlalchemy -from MinecraftAccountInfoGrabber import * +from sqlalchemy.exc import IntegrityError, DataError +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.sql import expression from difflib import SequenceMatcher +import enum +from BotErrors import * +from MinecraftAccountInfoGrabber import * SQL_Base = declarative_base() @@ -184,7 +183,7 @@ class DatabaseInterface: def delete_location(self, owner, name): expr = (Location.owner == owner) & (Location.name == name) - self.database.delete_entry(Shop, expr) + self.database.delete_entry(Location, expr) class DiscordDatabaseInterface(DatabaseInterface): @@ -243,6 +242,7 @@ class GeoffreyDatabase: self.engine = create_engine(engine_arg, echo=True) Session = sessionmaker(bind=self.engine) self.session = Session() + self.meta = MetaData() SQL_Base.metadata.create_all(self.engine) def add_object(self, obj): @@ -252,7 +252,12 @@ class GeoffreyDatabase: self.session.add(obj) self.session.commit() except IntegrityError: + self.session.rollback() raise EntryNameNotUniqueError + except DataError: + self.session.rollback() + raise StringTooLong + def query_by_filter(self, obj_type, * args): filter_value = self.combine_filter(args) @@ -264,6 +269,7 @@ class GeoffreyDatabase: if entry.first() is not None: entry.delete() + self.session.commit() else: raise DeleteEntryError @@ -277,7 +283,7 @@ class GeoffreyDatabase: return s def combine_filter(self, filter_value): - return sqlalchemy.sql.expression.and_(filter_value[0]) + return expression.and_(filter_value[0]) class TunnelDirection(enum.Enum): @@ -321,11 +327,12 @@ class Dimension(enum.Enum): class Player(SQL_Base): __tablename__ = 'Players' id = Column(Integer, primary_key=True, autoincrement=True) - mc_uuid = Column(String) - discord_uuid = Column(String) - name = Column(String) + mc_uuid = Column(String(128)) + discord_uuid = Column(String(128)) + name = Column(String(128)) locations = relationship("Location", back_populates="owner", lazy='dynamic', cascade="save-update, merge, delete, delete-orphan") + tunnels = relationship("Tunnel", back_populates="owner", lazy='dynamic', cascade="save-update, merge, delete, delete-orphan") @@ -342,8 +349,8 @@ class Tunnel(SQL_Base): tunnel_direction = Column(Enum(TunnelDirection)) owner_id = Column(Integer, ForeignKey('Players.id')) owner = relationship("Player", back_populates="tunnels", cascade="save-update, merge, delete") - location_id = Column(Integer, ForeignKey('Locations.id')) - location = relationship("Location", back_populates="tunnel", cascade="save-update, merge, delete") + location_id = Column(Integer, ForeignKey('Locations.id', ondelete='CASCADE')) + location = relationship("Location", back_populates="tunnel") def __init__(self, owner, tunnel_color, tunnel_number, location=None): try: @@ -362,18 +369,16 @@ class Location(SQL_Base): __tablename__ = 'Locations' id = Column(Integer, primary_key=True) - name = Column(String, unique=True) + name = Column(String(128), unique=True, ) x = Column(Integer) - y = Column(Integer) z = Column(Integer) - tunnel = relationship("Tunnel", back_populates="location", uselist=False, - cascade="save-update, merge, delete, delete-orphan") + tunnel = relationship("Tunnel", uselist=False, cascade="all, delete-orphan") dimension = Column(Enum(Dimension)) - owner_id = Column(Integer, ForeignKey('Players.id')) - owner = relationship("Player", back_populates="locations", cascade="save-update, merge, delete") - type = Column(String) + owner_id = Column(Integer, ForeignKey('Players.id', ondelete='CASCADE')) + owner = relationship("Player", back_populates="locations", cascade="all, delete-orphan", single_parent=True) + type = Column(String(128)) __mapper_args__ = { 'polymorphic_on': type, @@ -413,15 +418,13 @@ class Location(SQL_Base): class Shop(Location): __tablename__ = 'Shops' - shop_id = Column(Integer, ForeignKey('Locations.id'), primary_key=True) - name = column_property(Column(String), Location.name) + shop_id = Column(Integer, ForeignKey('Locations.id', ondelete='CASCADE'), primary_key=True) + name = column_property(Column(String(128)), Location.name) inventory = relationship('ItemListing', back_populates='shop', cascade='all, delete-orphan') __mapper_args__ = { 'polymorphic_identity': 'Shop', } - column_property() - def inv_to_str(self): if len(self.inventory.all()) != 0: @@ -449,12 +452,12 @@ class ItemListing(SQL_Base): __tablename__ = 'Items' id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String) + name = Column(String(128)) price = Column(Integer) amount = Column(Integer) - shop_id = Column(Integer, ForeignKey('Shops.shop_id')) - shop = relationship("Shop", back_populates="inventory") + shop_id = Column(Integer, ForeignKey('Shops.shop_id', ondelete='CASCADE')) + shop = relationship("Shop", back_populates="inventory", single_parent=True) def __init__(self, name, price, amount, shop): self.name = name diff --git a/Geoffrey.py b/Geoffrey.py index 08ca296..014b894 100644 --- a/Geoffrey.py +++ b/Geoffrey.py @@ -3,7 +3,8 @@ from DatabaseModels import * from BotErrors import * from MinecraftAccountInfoGrabber import * from itertools import zip_longest -import configparser +from BotConfig import * + import shlex #from WebInterface import * @@ -313,46 +314,15 @@ def get_args_dict(args): return {} -def create_config(): - config['Discord'] = {'Token': ''} - config['SQL'] = {'Dialect+Driver': 'Test', 'username': '', 'password':'', 'host': '', 'port': '', 'database':''} - - with open('GeoffreyConfig.ini', 'w') as configfile: - config.write(configfile) - -def get_engine_arg(config): - driver = config['SQL']['Dialect+Driver'] - username = config['SQL']['username'] - password = config['SQL']['password'] - host = config['SQL']['host'] - port = config['SQL']['port'] - database_name = config['SQL']['database'] - - engine_args = '{}://{}:{}@{}:{}/{}' - - return engine_args.format(driver, username, password, host, port, database_name) - - # Bot Startup ****************************************************************** +config = read_config() -config = configparser.ConfigParser() -config.read('GeoffreyConfig.ini') +TOKEN = config['Discord']['Token'] -if len(config.sections()) == 0: - create_config() - print("GeoffreyConfig.ini generated.") - quit(0) -else: - TOKEN = config['Discord']['Token'] +engine_arg = get_engine_arg(config) - if config['SQL']['dialect+driver'] == 'Test': - engine_arg = 'sqlite:///temp.db' - else: - engine_arg = get_engine_arg(config) +database_interface = DiscordDatabaseInterface(engine_arg) - database_interface = DiscordDatabaseInterface(engine_arg) - #WebInterface('127.0.0.1', 8081, database) - - bot.run(TOKEN) +bot.run(TOKEN) diff --git a/test_geoffreyDatabase.py b/test_geoffreyDatabase.py index e66e3ae..3f97283 100644 --- a/test_geoffreyDatabase.py +++ b/test_geoffreyDatabase.py @@ -2,14 +2,28 @@ from unittest import TestCase from DatabaseModels import * from BotErrors import * from MinecraftAccountInfoGrabber import * +from BotConfig import * class TestGeoffreyDatabase(TestCase): def setUp(self): - self.interface = DiscordDatabaseInterface('sqlite:///:memory:') + config = read_config() + + engine_arg = config['SQL']['test_args'] + + self.interface = DiscordDatabaseInterface(engine_arg) self.owner = Player('ZeroHD', '143072699567177728') self.loc = Location('test', 1, 3, self.owner, dimension='Nether') self.tunnel = Tunnel(self.owner, 'Green', 105, self.loc) + def tearDown(self): + self.interface.database.session.query(Tunnel).delete() + self.interface.database.session.query(ItemListing).delete() + self.interface.database.session.query(Shop).delete() + self.interface.database.session.query(Location).delete() + self.interface.database.session.query(Player).delete() + + self.interface.database.session.commit() + def test_add_object(self): self.interface.database.add_object(self.loc) self.interface.database.add_object(self.owner) @@ -184,12 +198,11 @@ class TestGeoffreyDatabase(TestCase): def test_big_input(self): owner = self.add_player() - loc = self.interface.add_location('143072699567177728', - 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT', 0, 0, 0) - loc_list = self.interface.find_location_by_owner(owner) - - self.assertEqual(loc_list[0].id, loc.id) + self.assertRaises(StringTooLong, self.interface.add_location, '143072699567177728', + 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT' + 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT' + 'TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT', 0, 0,) def test_duplicate_name(self): self.add_player() @@ -207,7 +220,7 @@ class TestGeoffreyDatabase(TestCase): self.interface.delete_location('143072699567177728', 'test') shops = self.interface.find_shop_selling_item('dirt') - self.assertGreater(len(shops), 0) + self.assertEqual(len(shops), 0)