Set retries to 0 by default; refactor the progress-bar executor (#1899)

* Set retries to 0 by default; refactor the progress-bar executor
This commit is contained in:
Soxoj
2024-11-26 19:07:15 +01:00
committed by GitHub
parent 80cf70d151
commit 8a98aa9eaa
3 changed files with 29 additions and 34 deletions
+27 -32
View File
@@ -100,29 +100,33 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
self.workers_count = kwargs.get('in_parallel', 10) self.workers_count = kwargs.get('in_parallel', 10)
self.queue = asyncio.Queue(self.workers_count) self.queue = asyncio.Queue(self.workers_count)
self.timeout = kwargs.get('timeout') self.timeout = kwargs.get('timeout')
# a function to show updated progress, alive_bar by default # Pass a progress function; alive_bar by default
self.progress_func = kwargs.get('progress_func', None) self.progress_func = kwargs.get('progress_func', alive_bar)
self.progress = None self.progress = None
# TODO: tests
async def increment_progress(self, count): async def increment_progress(self, count):
update_func = self.progress.update """Update progress by calling the provided progress function."""
if self.progress:
if asyncio.iscoroutinefunction(update_func): if asyncio.iscoroutinefunction(self.progress):
await update_func(count) await self.progress(count)
else: else:
update_func(count) self.progress(count)
await asyncio.sleep(0) await asyncio.sleep(0)
# TODO: tests
async def stop_progress(self): async def stop_progress(self):
close_func = self.progress.close """Stop the progress tracking."""
if hasattr(self.progress, "close") and self.progress:
if asyncio.iscoroutinefunction(close_func): close_func = self.progress.close
await close_func() if asyncio.iscoroutinefunction(close_func):
else: await close_func()
close_func() else:
await asyncio.sleep(0) close_func()
await asyncio.sleep(0)
async def worker(self): async def worker(self):
"""Consume tasks from the queue and process them."""
while True: while True:
try: try:
f, args, kwargs = self.queue.get_nowait() f, args, kwargs = self.queue.get_nowait()
@@ -144,34 +148,25 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
self.queue.task_done() self.queue.task_done()
async def _run(self, queries: Iterable[QueryDraft]): async def _run(self, queries: Iterable[QueryDraft]):
"""Main runner function to execute tasks with progress tracking."""
self.results: List[Any] = [] self.results: List[Any] = []
queries_list = list(queries) queries_list = list(queries)
min_workers = min(len(queries_list), self.workers_count) min_workers = min(len(queries_list), self.workers_count)
workers = [create_task_func()(self.worker()) for _ in range(min_workers)] workers = [create_task_func()(self.worker()) for _ in range(min_workers)]
# Initialize the progress bar
if self.progress_func: if self.progress_func:
self.progress = self.progress_func(total=len(queries_list)) with self.progress_func(len(queries_list), title="Searching", force_tty=True) as bar:
self.progress = bar # Assign alive_bar's callable to self.progress
for t in queries_list:
await self.queue.put(t)
await self.queue.join()
for w in workers:
w.cancel()
await self.stop_progress()
else:
# Initialize alive_progress bar
with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar:
self.update = bar # `alive_bar` uses its instance to update progress
# Add tasks to the queue
for t in queries_list: for t in queries_list:
await self.queue.put(t) await self.queue.put(t)
# Wait for tasks to complete
await self.queue.join() await self.queue.join()
# Cancel any remaining workers
for w in workers: for w in workers:
w.cancel() w.cancel()
+1 -1
View File
@@ -24,7 +24,7 @@
"supposed_usernames": [ "supposed_usernames": [
"alex", "god", "admin", "red", "blue", "john" "alex", "god", "admin", "red", "blue", "john"
], ],
"retries_count": 1, "retries_count": 0,
"sites_db_path": "resources/data.json", "sites_db_path": "resources/data.json",
"timeout": 30, "timeout": 30,
"max_connections": 100, "max_connections": 100,
+1 -1
View File
@@ -28,7 +28,7 @@ DEFAULT_ARGS: Dict[str, Any] = {
'print_not_found': False, 'print_not_found': False,
'proxy': None, 'proxy': None,
'reports_sorting': 'default', 'reports_sorting': 'default',
'retries': 1, 'retries': 0,
'self_check': False, 'self_check': False,
'site_list': [], 'site_list': [],
'stats': False, 'stats': False,