add a row factory to the database module
This commit is contained in:
parent
94c1447093
commit
f02b8d85c6
21
database.py
21
database.py
|
@ -11,6 +11,7 @@ class Database:
|
|||
"""Create a database file if not exists."""
|
||||
# TODO: think about removing check_same_thread=False
|
||||
self.conn = sqlite3.connect(path, check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self.cur = self.conn.cursor()
|
||||
self.__init_schema()
|
||||
|
||||
|
@ -26,7 +27,7 @@ class Database:
|
|||
row = self.cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return row[0]
|
||||
return row['id']
|
||||
|
||||
def add_feed(self, url: str) -> int:
|
||||
"""Add a feed to the database and return its id."""
|
||||
|
@ -40,7 +41,7 @@ class Database:
|
|||
row = self.cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return row[0]
|
||||
return row['id']
|
||||
|
||||
def subscribe_user_by_url(self, user_id: int, url: str) -> None:
|
||||
"""Subscribe user to the feed creating it if does not exist yet."""
|
||||
|
@ -92,30 +93,28 @@ class Database:
|
|||
|
||||
def get_feed_subscribers_count(self, feed_id: int) -> int:
|
||||
"""Count feed subscribers."""
|
||||
self.cur.execute('SELECT COUNT(user_id) FROM subscriptions WHERE feed_id = ?', [feed_id])
|
||||
self.cur.execute('SELECT COUNT(user_id) AS amount_subscribers FROM subscriptions WHERE feed_id = ?', [feed_id])
|
||||
row = self.cur.fetchone()
|
||||
if row is None:
|
||||
return 0
|
||||
return int(row[0])
|
||||
return row['amount_subscribers']
|
||||
|
||||
def find_feed_subscribers(self, feed_id: int) -> list[int]:
|
||||
"""Return feed subscribers"""
|
||||
self.cur.execute('SELECT telegram_id FROM users WHERE id IN (SELECT user_id FROM subscriptions WHERE feed_id = ?)', [feed_id])
|
||||
subscribers = self.cur.fetchall()
|
||||
return list(map(lambda x: x[0], subscribers))
|
||||
return list(map(lambda x: x['telegram_id'], subscribers))
|
||||
|
||||
def find_feeds(self) -> list:
|
||||
def find_feeds(self) -> list[sqlite3.Row]:
|
||||
"""Get a list of feeds."""
|
||||
self.cur.execute('SELECT * FROM feeds')
|
||||
return self.cur.fetchall()
|
||||
|
||||
def find_user_feeds(self, user_id: int) -> list[tuple]:
|
||||
def find_user_feeds(self, user_id: int) -> list[sqlite3.Row]:
|
||||
"""Return a list of feeds the user is subscribed to."""
|
||||
self.cur.execute('SELECT * FROM feeds WHERE id IN (SELECT feed_id FROM subscriptions WHERE user_id = ?)',
|
||||
[user_id])
|
||||
return self.cur.fetchall()
|
||||
|
||||
def find_feed_items(self, feed_id: int) -> list[tuple]:
|
||||
def find_feed_items(self, feed_id: int) -> list[sqlite3.Row]:
|
||||
"""Get last feed items."""
|
||||
self.cur.execute('SELECT * FROM feeds_last_items WHERE feed_id = ?', [feed_id])
|
||||
return self.cur.fetchall()
|
||||
|
@ -125,7 +124,7 @@ class Database:
|
|||
items = self.find_feed_items(feed_id)
|
||||
if not items:
|
||||
return items
|
||||
return list(map(lambda x: x[1], items))
|
||||
return list(map(lambda x: x['url'], items))
|
||||
|
||||
def update_feed_items(self, feed_id: int, new_items: list[FeedItem]) -> None:
|
||||
"""Replace last feed items with a list items that receive."""
|
||||
|
|
Loading…
Reference in a new issue