mirror of
https://github.com/soxoj/maigret.git
synced 2026-05-07 06:24:35 +00:00
Retries set to 0 by default, refactored code of executor with progress (#1899)
* Retries set to 0 by default, refactored code of executor with progress
This commit is contained in:
+27
-32
@@ -100,29 +100,33 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
|||||||
self.workers_count = kwargs.get('in_parallel', 10)
|
self.workers_count = kwargs.get('in_parallel', 10)
|
||||||
self.queue = asyncio.Queue(self.workers_count)
|
self.queue = asyncio.Queue(self.workers_count)
|
||||||
self.timeout = kwargs.get('timeout')
|
self.timeout = kwargs.get('timeout')
|
||||||
# a function to show updated progress, alive_bar by default
|
# Pass a progress function; alive_bar by default
|
||||||
self.progress_func = kwargs.get('progress_func', None)
|
self.progress_func = kwargs.get('progress_func', alive_bar)
|
||||||
self.progress = None
|
self.progress = None
|
||||||
|
|
||||||
|
# TODO: tests
|
||||||
async def increment_progress(self, count):
|
async def increment_progress(self, count):
|
||||||
update_func = self.progress.update
|
"""Update progress by calling the provided progress function."""
|
||||||
|
if self.progress:
|
||||||
if asyncio.iscoroutinefunction(update_func):
|
if asyncio.iscoroutinefunction(self.progress):
|
||||||
await update_func(count)
|
await self.progress(count)
|
||||||
else:
|
else:
|
||||||
update_func(count)
|
self.progress(count)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# TODO: tests
|
||||||
async def stop_progress(self):
|
async def stop_progress(self):
|
||||||
close_func = self.progress.close
|
"""Stop the progress tracking."""
|
||||||
|
if hasattr(self.progress, "close") and self.progress:
|
||||||
if asyncio.iscoroutinefunction(close_func):
|
close_func = self.progress.close
|
||||||
await close_func()
|
if asyncio.iscoroutinefunction(close_func):
|
||||||
else:
|
await close_func()
|
||||||
close_func()
|
else:
|
||||||
await asyncio.sleep(0)
|
close_func()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
async def worker(self):
|
async def worker(self):
|
||||||
|
"""Consume tasks from the queue and process them."""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
f, args, kwargs = self.queue.get_nowait()
|
f, args, kwargs = self.queue.get_nowait()
|
||||||
@@ -144,34 +148,25 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
|||||||
self.queue.task_done()
|
self.queue.task_done()
|
||||||
|
|
||||||
async def _run(self, queries: Iterable[QueryDraft]):
|
async def _run(self, queries: Iterable[QueryDraft]):
|
||||||
|
"""Main runner function to execute tasks with progress tracking."""
|
||||||
self.results: List[Any] = []
|
self.results: List[Any] = []
|
||||||
queries_list = list(queries)
|
queries_list = list(queries)
|
||||||
min_workers = min(len(queries_list), self.workers_count)
|
min_workers = min(len(queries_list), self.workers_count)
|
||||||
|
|
||||||
workers = [create_task_func()(self.worker()) for _ in range(min_workers)]
|
workers = [create_task_func()(self.worker()) for _ in range(min_workers)]
|
||||||
|
|
||||||
|
# Initialize the progress bar
|
||||||
if self.progress_func:
|
if self.progress_func:
|
||||||
self.progress = self.progress_func(total=len(queries_list))
|
with self.progress_func(len(queries_list), title="Searching", force_tty=True) as bar:
|
||||||
|
self.progress = bar # Assign alive_bar's callable to self.progress
|
||||||
for t in queries_list:
|
|
||||||
await self.queue.put(t)
|
|
||||||
|
|
||||||
await self.queue.join()
|
|
||||||
|
|
||||||
for w in workers:
|
|
||||||
w.cancel()
|
|
||||||
|
|
||||||
await self.stop_progress()
|
|
||||||
else:
|
|
||||||
# Initialize alive_progress bar
|
|
||||||
with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar:
|
|
||||||
self.update = bar # `alive_bar` uses its instance to update progress
|
|
||||||
|
|
||||||
|
# Add tasks to the queue
|
||||||
for t in queries_list:
|
for t in queries_list:
|
||||||
await self.queue.put(t)
|
await self.queue.put(t)
|
||||||
|
|
||||||
|
# Wait for tasks to complete
|
||||||
await self.queue.join()
|
await self.queue.join()
|
||||||
|
|
||||||
|
# Cancel any remaining workers
|
||||||
for w in workers:
|
for w in workers:
|
||||||
w.cancel()
|
w.cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@
|
|||||||
"supposed_usernames": [
|
"supposed_usernames": [
|
||||||
"alex", "god", "admin", "red", "blue", "john"
|
"alex", "god", "admin", "red", "blue", "john"
|
||||||
],
|
],
|
||||||
"retries_count": 1,
|
"retries_count": 0,
|
||||||
"sites_db_path": "resources/data.json",
|
"sites_db_path": "resources/data.json",
|
||||||
"timeout": 30,
|
"timeout": 30,
|
||||||
"max_connections": 100,
|
"max_connections": 100,
|
||||||
|
|||||||
+1
-1
@@ -28,7 +28,7 @@ DEFAULT_ARGS: Dict[str, Any] = {
|
|||||||
'print_not_found': False,
|
'print_not_found': False,
|
||||||
'proxy': None,
|
'proxy': None,
|
||||||
'reports_sorting': 'default',
|
'reports_sorting': 'default',
|
||||||
'retries': 1,
|
'retries': 0,
|
||||||
'self_check': False,
|
'self_check': False,
|
||||||
'site_list': [],
|
'site_list': [],
|
||||||
'stats': False,
|
'stats': False,
|
||||||
|
|||||||
Reference in New Issue
Block a user