diff --git a/database.py b/database.py index ce25864..9a20480 100644 --- a/database.py +++ b/database.py @@ -1,5 +1,6 @@ -import psycopg2 -import psycopg2.extras +import psycopg2 +from psycopg2.extensions import connection +from psycopg2.extras import DictCursor, DictRow from logging import Logger @@ -14,16 +15,16 @@ class Database: """Create a database file if not exists.""" self.log: Logger = log self.log.debug('Database.__init__(DSN=\'%s\')', dsn) - self.conn = psycopg2.connect(dsn) - self.cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) + self.conn: connection = psycopg2.connect(dsn) + self.cur: DictCursor = self.conn.cursor(cursor_factory=DictCursor) self.__init_schema() def add_user(self, telegram_id: int) -> int: """Add a user's telegram id to the database and return its database id.""" self.log.debug('add_user(telegram_id=\'%s\')', telegram_id) - self.cur.execute('INSERT INTO users (telegram_id) VALUES (%s)', [telegram_id]) + id = self.cur.execute('INSERT INTO users (telegram_id) VALUES (%s) RETURNING id', [telegram_id]) self.conn.commit() - return self.find_user(telegram_id) + return id def find_user(self, telegram_id: int) -> int | None: """Get a user's telegram id and return its database id.""" @@ -37,9 +38,9 @@ class Database: def add_feed(self, url: str) -> int: """Add a feed to the database and return its id.""" self.log.debug('add_feed(url=\'%s\')', url) - self.cur.execute('INSERT INTO feeds (url) VALUES (%s)', [url]) + id = self.cur.execute('INSERT INTO feeds (url) VALUES (%s) RETURNING id', [url]) self.conn.commit() - return self.find_feed_by_url(url) + return id def find_feed_by_url(self, url: str) -> int | None: """Find feed ID by url.""" @@ -126,14 +127,14 @@ class Database: self.cur.execute('SELECT * FROM feeds') return self.cur.fetchall() - def find_user_feeds(self, user_id: int) -> list[psycopg2.extras.DictRow]: + def find_user_feeds(self, user_id: int) -> list[DictRow]: """Return a list of feeds the user is subscribed to.""" self.log.debug('find_user_feeds(user_id=\'%s\')', user_id) self.cur.execute('SELECT * FROM feeds WHERE id IN (SELECT feed_id FROM subscriptions WHERE user_id = %s)', [user_id]) return self.cur.fetchall() - def find_feed_items(self, feed_id: int) -> list[psycopg2.extras.DictRow]: + def find_feed_items(self, feed_id: int) -> list[DictRow]: """Get last feed items.""" self.log.debug('find_feed_items(feed_id=\'%s\')', feed_id) self.cur.execute('SELECT * FROM feeds_last_items WHERE feed_id = %s', [feed_id]) @@ -166,19 +167,16 @@ class Database: self.cur.execute('CREATE TABLE IF NOT EXISTS feeds (id SERIAL PRIMARY KEY, url TEXT NOT NULL UNIQUE)') self.cur.execute( 'CREATE TABLE IF NOT EXISTS subscriptions (' - ' user_id INTEGER,' - ' feed_id INTEGER,' - ' UNIQUE (user_id, feed_id),' - ' FOREIGN KEY(user_id) REFERENCES users(id),' - ' FOREIGN KEY(feed_id) REFERENCES feeds(id)' + ' user_id INTEGER REFERENCES users,' + ' feed_id INTEGER REFERENCES feeds,' + ' UNIQUE (user_id, feed_id)' ')' ) self.cur.execute( 'CREATE TABLE IF NOT EXISTS feeds_last_items (' - ' feed_id INTEGER,' + ' feed_id INTEGER REFERENCES feeds ON DELETE CASCADE,' ' url TEXT NOT NULL,' ' title TEXT,' - ' description TEXT,' - ' FOREIGN KEY(feed_id) REFERENCES feeds(id)' + ' description TEXT' ')' )