diff --git a/database.py b/database.py index 4ea6a67..c584eb1 100644 --- a/database.py +++ b/database.py @@ -2,6 +2,7 @@ from logging import Logger import psycopg2 from psycopg2.extensions import connection from psycopg2.extras import DictCursor, DictRow +from yoyo import get_backend, read_migrations from exceptions import DisplayableException from rss import FeedItem @@ -10,12 +11,12 @@ class Database: """Implement interaction with the database.""" def __init__(self, dsn: str, log: Logger) -> None: - """Create a database file if not exists.""" + """Initialize the database""" self.log: Logger = log self.log.debug('Database.__init__(DSN=\'%s\')', dsn) self.conn: connection = psycopg2.connect(dsn) self.cur: DictCursor = self.conn.cursor(cursor_factory=DictCursor) - self.__init_schema() + self.__migrate(dsn) def add_user(self, telegram_id: int) -> int: """Add a user's telegram id to the database and return its database id.""" @@ -156,26 +157,14 @@ class Database: 'INSERT INTO feeds_last_items (feed_id, url, guid) VALUES (%s, %s, %s)', new_items) self.conn.commit() - def __init_schema(self) -> None: - self.log.debug('__init_schema()') - self.cur.execute( - 'CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, telegram_id INTEGER NOT NULL UNIQUE)' - ) - self.cur.execute('CREATE TABLE IF NOT EXISTS feeds (id SERIAL PRIMARY KEY, url TEXT NOT NULL UNIQUE)') - self.cur.execute( - 'CREATE TABLE IF NOT EXISTS subscriptions (' - ' user_id INTEGER REFERENCES users,' - ' feed_id INTEGER REFERENCES feeds,' - ' UNIQUE (user_id, feed_id)' - ')' - ) - self.cur.execute( - 'CREATE TABLE IF NOT EXISTS feeds_last_items (' - ' feed_id INTEGER REFERENCES feeds ON DELETE CASCADE,' - ' url TEXT NOT NULL,' - ' guid TEXT' - ')' - ) + def __migrate(self, dsn: str) -> None: + """Migrate or initialize the database schema""" + self.log.debug(f'Database.__migrate(dsn={dsn})') + backend = get_backend(dsn) + migrations = read_migrations('./migrations') + + with backend.lock(): + backend.apply_migrations(backend.to_apply(migrations)) @staticmethod def __dictrow_to_dict_list(rows: list[DictRow]) -> list[dict]: diff --git a/migrations/0000.initial_schema.py b/migrations/0000.initial_schema.py new file mode 100644 index 0000000..be7fa0a --- /dev/null +++ b/migrations/0000.initial_schema.py @@ -0,0 +1,30 @@ +from yoyo import step + +steps = [ + step( + 'CREATE TABLE users (' + ' id SERIAL PRIMARY KEY,' + ' telegram_id INTEGER NOT NULL UNIQUE' + ')' + ), + step( + 'CREATE TABLE feeds (' + ' id SERIAL PRIMARY KEY,' + ' url TEXT NOT NULL UNIQUE' + ')' + ), + step( + 'CREATE TABLE subscriptions (' + ' user_id INTEGER REFERENCES users,' + ' feed_id INTEGER REFERENCES feeds,' + ' UNIQUE (user_id, feed_id)' + ')' + ), + step( + 'CREATE TABLE feeds_last_items (' + ' feed_id INTEGER REFERENCES feeds ON DELETE CASCADE,' + ' url TEXT NOT NULL,' + ' guid TEXT' + ')' + ) +] diff --git a/requirements.txt b/requirements.txt index cc5e556..fd8e157 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ six==1.16.0 urllib3==1.26.9 validators==0.19.0 webencodings==0.5.1 +yoyo-migrations==7.3.2