Commit 87d5630d authored by Bruno Barcarol Guimarães's avatar Bruno Barcarol Guimarães
Browse files

Add type annotations

parent 330bfab0
......@@ -6,12 +6,13 @@ import operator
import os
import sqlite3
import sys
import typing
import urllib.parse
import xml.etree.ElementTree
import youtube_dl
def main(argv):
def main(argv: typing.Sequence[str]):
args = parse_args(argv)
args.file = args.file or db_file()
with sqlite3.connect(args.file) as conn:
......@@ -23,7 +24,7 @@ def main(argv):
cmd(subs, **args_d)
def parse_args(argv):
def parse_args(argv: typing.Sequence[str]):
parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter,
description='Manage Youtube subscriptions',
......@@ -110,18 +111,18 @@ class Client(object):
def __init__(self, verbose):
self.debug = (lambda *_: None) if verbose < 2 else print
def __init__(self, verbose, ydl=None):
def __init__(self, verbose: int, ydl=None):
self._opts = {'logger': self.logger(verbose), 'extract_flat': True}
self._ydl = ydl or youtube_dl.YoutubeDL(self._opts)
def info(self, url):
def info(self, url: str):
return self._ydl.extract_info(url, download=False)
def channel_entries(self, yt_id):
def channel_entries(self, yt_id: str):
url = self.info(self.CHANNEL_URL.format(yt_id))['url']
return self.info(url)['entries']
def upload_date(self, yt_id):
def upload_date(self, yt_id: str):
d = self.info(self.VIDEO_URL.format(yt_id))['upload_date']
ts = datetime.datetime.strptime(d, '%Y%m%d').timestamp()
return int(ts)
......@@ -129,29 +130,29 @@ class Client(object):
class Query(object):
@classmethod
def make_args(cls, n): return ', '.join(('?',) * n)
def make_args(cls, n: int): return ', '.join(('?',) * n)
def __init__(self, table):
def __init__(self, table: str):
self.table = table
self.fields = []
self.joins = []
self.wheres = []
self.group_bys = []
self.order_bys = []
self.fields: typing.List[str] = []
self.joins: typing.List[str] = []
self.wheres: typing.List[str] = []
self.group_bys: typing.List[str] = []
self.order_bys: typing.List[str] = []
def add_fields(self, *name):
def add_fields(self, *name: str):
self.fields.extend(name)
def add_joins(self, *joins):
def add_joins(self, *joins: str):
self.joins.extend(joins)
def add_filter(self, *exprs):
def add_filter(self, *exprs: str):
self.wheres.extend(exprs)
def add_group(self, *exprs):
def add_group(self, *exprs: str):
self.group_bys.extend(exprs)
def add_order(self, *exprs):
def add_order(self, *exprs: str):
self.order_bys.extend(exprs)
def _query(self, l):
......@@ -180,14 +181,16 @@ class Query(object):
class Subscriptions(object):
def __init__(self, verbose, conn, now=None):
def __init__(
self, verbose: int, conn: sqlite3.Connection,
now: typing.Callable[[], datetime.datetime]=None):
self._verbose = verbose
self._conn = conn
self._now = now or datetime.datetime.now
def _log(self, *msg): self._verbose and print(*msg)
def raw(self, l):
def raw(self, l: typing.Iterable):
c = self._conn.cursor()
return [c.execute(x).fetchall() for x in l]
......@@ -214,7 +217,9 @@ class Subscriptions(object):
' watched boolean not null default(0),'
' foreign key(sub) references subs(id))')
def list(self, ids, show_id, show_ids, unwatched):
def list(
self, ids: typing.Sequence[str],
show_id: bool, show_ids: bool, unwatched: bool):
assert(bool(show_id) + bool(show_ids) <= 1)
q = Query(table='subs')
q.add_order('subs.id')
......@@ -235,7 +240,9 @@ class Subscriptions(object):
for _ in map(print, map(' '.join, c)): pass
def list_videos(
self, subscriptions, n, by_name, url, flat, show_ids, watched):
self, subscriptions: typing.Collection[str], n: int,
by_name: bool, url: bool, flat: bool,
show_ids: bool, watched: bool):
q = Query('videos')
q.add_joins('join subs on subs.id == videos.sub')
q.add_fields('subs.id', 'subs.name', 'videos.watched')
......@@ -244,7 +251,7 @@ class Subscriptions(object):
q.add_fields('videos.yt_id')
if not url:
q.add_fields('videos.title')
args = []
args: typing.List[typing.Union[str, int]] = []
if subscriptions:
if by_name:
q.add_filter('({})'.format(
......@@ -287,7 +294,7 @@ class Subscriptions(object):
total[1] += n
print(*fmt(*total), 'total')
def import_xml(self, path):
def import_xml(self, path: str):
tree = xml.etree.ElementTree.ElementTree(file=path)
subs = tree.find('.//*[@title="YouTube Subscriptions"]')
skip = lambda e, msg: print(
......@@ -317,17 +324,17 @@ class Subscriptions(object):
self._log('adding subscriptions:', [x[0] for x in add])
c.executemany('insert into subs (name, yt_id) values (?, ?)', add)
def _sub_exists(self, c, yt_id):
def _sub_exists(self, c: sqlite3.Cursor, yt_id: str):
return bool(c
.execute('select 1 from subs where yt_id == ?', (yt_id,))
.fetchall())
def add(self, yt_id, name):
def add(self, yt_id: str, name: str):
self._conn.cursor().execute(
'insert into subs (yt_id, name) values (?, ?)',
(yt_id, name))
def update(self, items, cache, client=None):
def update(self, items: typing.Collection[str], cache: int, client=None):
now = int(self._now().timestamp())
cache = now - (cache if cache is not None else 24 * 60 * 60)
c = self._conn.cursor()
......@@ -356,22 +363,26 @@ class Subscriptions(object):
for _ in c: pass
self._log(f'{count() - initial_count} new videos added after @{cache}')
def _video_exists(self, c, yt_id):
def _video_exists(self, c: sqlite3.Cursor, yt_id: str):
return bool(c
.execute('select 1 from videos where yt_id == ?', (yt_id,))
.fetchall())
def _add_video(self, c, sub_id, yt_id, title):
def _add_video(
self, c: sqlite3.Cursor, sub_id: str, yt_id: str, title: str):
c.execute(
'insert into videos (sub, yt_id, title) values (?, ?, ?)',
(sub_id, yt_id, title))
def _update_sub(self, c, sub_id, last_update):
def _update_sub(self, c: sqlite3.Cursor, sub_id: str, last_update: int):
c.execute(
'update subs set last_update = ? where id == ?',
(last_update, sub_id))
def watched(self, items, subs, oldest, older_than, url, remove):
def watched(
self, items: typing.Collection[str],
subs: bool, oldest: bool, older_than: bool, url: bool,
remove: bool=False):
assert(bool(subs) + bool(oldest) + bool(older_than) <= 1)
q = Query('videos')
if subs:
......@@ -397,6 +408,7 @@ class Subscriptions(object):
else:
q.add_filter('yt_id in ({})'.format(Query.make_args(len(items))))
c = self._conn.cursor()
ids: typing.Collection[typing.Union[str, int]]
if url:
q.add_fields('id, yt_id')
items = c .execute(q.query(), items).fetchall()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment