mirror of
https://github.com/soxoj/maigret.git
synced 2026-05-07 06:24:35 +00:00
Improved usability of external progressbar func (#476)
This commit is contained in:
+4
-1
@@ -580,6 +580,8 @@ async def maigret(
|
||||
cookies=None,
|
||||
retries=0,
|
||||
check_domains=False,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> QueryResultWrapper:
|
||||
"""Main search func
|
||||
|
||||
@@ -660,7 +662,8 @@ async def maigret(
|
||||
executor = AsyncioSimpleExecutor(logger=logger)
|
||||
else:
|
||||
executor = AsyncioProgressbarQueueExecutor(
|
||||
logger=logger, in_parallel=max_connections, timeout=timeout + 0.5
|
||||
logger=logger, in_parallel=max_connections, timeout=timeout + 0.5,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
# make options objects for all the requests
|
||||
|
||||
+22
-2
@@ -81,6 +81,22 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
||||
self.queue = asyncio.Queue(self.workers_count)
|
||||
self.timeout = kwargs.get('timeout')
|
||||
|
||||
async def increment_progress(self, count):
|
||||
update_func = self.progress.update
|
||||
if asyncio.iscoroutinefunction(update_func):
|
||||
await update_func(count)
|
||||
else:
|
||||
update_func(count)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def stop_progress(self):
|
||||
stop_func = self.progress.close
|
||||
if asyncio.iscoroutinefunction(stop_func):
|
||||
await stop_func()
|
||||
else:
|
||||
stop_func()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def worker(self):
|
||||
while True:
|
||||
try:
|
||||
@@ -96,7 +112,7 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
||||
result = kwargs.get('default')
|
||||
|
||||
self.results.append(result)
|
||||
self.progress.update(1)
|
||||
await self.increment_progress(1)
|
||||
self.queue.task_done()
|
||||
|
||||
async def _run(self, queries: Iterable[QueryDraft]):
|
||||
@@ -109,10 +125,14 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
||||
workers = [create_task_func()(self.worker()) for _ in range(min_workers)]
|
||||
|
||||
self.progress = self.progress_func(total=len(queries_list))
|
||||
|
||||
for t in queries_list:
|
||||
await self.queue.put(t)
|
||||
|
||||
await self.queue.join()
|
||||
|
||||
for w in workers:
|
||||
w.cancel()
|
||||
self.progress.close()
|
||||
|
||||
await self.stop_progress()
|
||||
return self.results
|
||||
|
||||
Reference in New Issue
Block a user