From f02b8d85c67f050c533c9081b8fbbda2852138bc Mon Sep 17 00:00:00 2001 From: mitsuha_s Date: Sun, 12 Jun 2022 20:08:11 +0000 Subject: [PATCH] add a row factory to the database module --- database.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/database.py b/database.py index 66b45e8..040ee7b 100644 --- a/database.py +++ b/database.py @@ -11,6 +11,7 @@ class Database: """Create a database file if not exists.""" # TODO: think about removing check_same_thread=False self.conn = sqlite3.connect(path, check_same_thread=False) + self.conn.row_factory = sqlite3.Row self.cur = self.conn.cursor() self.__init_schema() @@ -26,7 +27,7 @@ class Database: row = self.cur.fetchone() if row is None: return None - return row[0] + return row['id'] def add_feed(self, url: str) -> int: """Add a feed to the database and return its id.""" @@ -40,7 +41,7 @@ class Database: row = self.cur.fetchone() if row is None: return None - return row[0] + return row['id'] def subscribe_user_by_url(self, user_id: int, url: str) -> None: """Subscribe user to the feed creating it if does not exist yet.""" @@ -92,30 +93,28 @@ class Database: def get_feed_subscribers_count(self, feed_id: int) -> int: """Count feed subscribers.""" - self.cur.execute('SELECT COUNT(user_id) FROM subscriptions WHERE feed_id = ?', [feed_id]) + self.cur.execute('SELECT COUNT(user_id) AS amount_subscribers FROM subscriptions WHERE feed_id = ?', [feed_id]) row = self.cur.fetchone() - if row is None: - return 0 - return int(row[0]) + return row['amount_subscribers'] def find_feed_subscribers(self, feed_id: int) -> list[int]: """Return feed subscribers""" self.cur.execute('SELECT telegram_id FROM users WHERE id IN (SELECT user_id FROM subscriptions WHERE feed_id = ?)', [feed_id]) subscribers = self.cur.fetchall() - return list(map(lambda x: x[0], subscribers)) + return list(map(lambda x: x['telegram_id'], subscribers)) - def find_feeds(self) -> list: + def find_feeds(self) -> list[sqlite3.Row]: """Get a list of feeds.""" self.cur.execute('SELECT * FROM feeds') return self.cur.fetchall() - def find_user_feeds(self, user_id: int) -> list[tuple]: + def find_user_feeds(self, user_id: int) -> list[sqlite3.Row]: """Return a list of feeds the user is subscribed to.""" self.cur.execute('SELECT * FROM feeds WHERE id IN (SELECT feed_id FROM subscriptions WHERE user_id = ?)', [user_id]) return self.cur.fetchall() - def find_feed_items(self, feed_id: int) -> list[tuple]: + def find_feed_items(self, feed_id: int) -> list[sqlite3.Row]: """Get last feed items.""" self.cur.execute('SELECT * FROM feeds_last_items WHERE feed_id = ?', [feed_id]) return self.cur.fetchall() @@ -125,7 +124,7 @@ class Database: items = self.find_feed_items(feed_id) if not items: return items - return list(map(lambda x: x[1], items)) + return list(map(lambda x: x['url'], items)) def update_feed_items(self, feed_id: int, new_items: list[FeedItem]) -> None: """Replace last feed items with a list items that receive."""