add a row factory to the database module

This commit is contained in:
mitsuha_s 2022-06-12 20:08:11 +00:00
parent 94c1447093
commit f02b8d85c6
1 changed files with 10 additions and 11 deletions

View File

@ -11,6 +11,7 @@ class Database:
"""Create a database file if not exists."""
# TODO: think about removing check_same_thread=False
self.conn = sqlite3.connect(path, check_same_thread=False)
self.conn.row_factory = sqlite3.Row
self.cur = self.conn.cursor()
self.__init_schema()
@ -26,7 +27,7 @@ class Database:
row = self.cur.fetchone()
if row is None:
return None
return row[0]
return row['id']
def add_feed(self, url: str) -> int:
"""Add a feed to the database and return its id."""
@ -40,7 +41,7 @@ class Database:
row = self.cur.fetchone()
if row is None:
return None
return row[0]
return row['id']
def subscribe_user_by_url(self, user_id: int, url: str) -> None:
"""Subscribe user to the feed creating it if does not exist yet."""
@ -92,30 +93,28 @@ class Database:
def get_feed_subscribers_count(self, feed_id: int) -> int:
"""Count feed subscribers."""
self.cur.execute('SELECT COUNT(user_id) FROM subscriptions WHERE feed_id = ?', [feed_id])
self.cur.execute('SELECT COUNT(user_id) AS amount_subscribers FROM subscriptions WHERE feed_id = ?', [feed_id])
row = self.cur.fetchone()
if row is None:
return 0
return int(row[0])
return row['amount_subscribers']
def find_feed_subscribers(self, feed_id: int) -> list[int]:
"""Return feed subscribers"""
self.cur.execute('SELECT telegram_id FROM users WHERE id IN (SELECT user_id FROM subscriptions WHERE feed_id = ?)', [feed_id])
subscribers = self.cur.fetchall()
return list(map(lambda x: x[0], subscribers))
return list(map(lambda x: x['telegram_id'], subscribers))
def find_feeds(self) -> list:
def find_feeds(self) -> list[sqlite3.Row]:
"""Get a list of feeds."""
self.cur.execute('SELECT * FROM feeds')
return self.cur.fetchall()
def find_user_feeds(self, user_id: int) -> list[tuple]:
def find_user_feeds(self, user_id: int) -> list[sqlite3.Row]:
"""Return a list of feeds the user is subscribed to."""
self.cur.execute('SELECT * FROM feeds WHERE id IN (SELECT feed_id FROM subscriptions WHERE user_id = ?)',
[user_id])
return self.cur.fetchall()
def find_feed_items(self, feed_id: int) -> list[tuple]:
def find_feed_items(self, feed_id: int) -> list[sqlite3.Row]:
"""Get last feed items."""
self.cur.execute('SELECT * FROM feeds_last_items WHERE feed_id = ?', [feed_id])
return self.cur.fetchall()
@ -125,7 +124,7 @@ class Database:
items = self.find_feed_items(feed_id)
if not items:
return items
return list(map(lambda x: x[1], items))
return list(map(lambda x: x['url'], items))
def update_feed_items(self, feed_id: int, new_items: list[FeedItem]) -> None:
"""Replace last feed items with a list items that receive."""