Maigret bot support (custom progress function fixed) (#1898)

* Fixed progress/close functions
* Fixed tests: execution time increased with alive_progressbar
This commit is contained in:
Soxoj
2024-11-26 15:54:26 +01:00
committed by GitHub
parent 324c118530
commit ee25c61fc2
3 changed files with 49 additions and 16 deletions
+5 -3
View File
@@ -162,15 +162,16 @@ class AiodnsDomainResolver(CheckerBase):
self.resolver = aiodns.DNSResolver(loop=loop) self.resolver = aiodns.DNSResolver(loop=loop)
def prepare(self, url, headers=None, allow_redirects=True, timeout=0, method='get'): def prepare(self, url, headers=None, allow_redirects=True, timeout=0, method='get'):
return self.resolver.query(url, 'A') self.url = url
return None
async def check(self, future) -> Tuple[str, int, Optional[CheckError]]: async def check(self) -> Tuple[str, int, Optional[CheckError]]:
status = 404 status = 404
error = None error = None
text = '' text = ''
try: try:
res = await future res = await self.resolver.query(self.url, 'A')
text = str(res[0].host) text = str(res[0].host)
status = 200 status = 200
except aiodns.error.DNSError: except aiodns.error.DNSError:
@@ -530,6 +531,7 @@ def make_site_result(
# Store future request object in the results object # Store future request object in the results object
results_site["future"] = future results_site["future"] = future
results_site["checker"] = checker results_site["checker"] = checker
return results_site return results_site
+35 -4
View File
@@ -100,11 +100,26 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
self.workers_count = kwargs.get('in_parallel', 10) self.workers_count = kwargs.get('in_parallel', 10)
self.queue = asyncio.Queue(self.workers_count) self.queue = asyncio.Queue(self.workers_count)
self.timeout = kwargs.get('timeout') self.timeout = kwargs.get('timeout')
self.bar_update = None # Store the update function from alive_bar # a function to show updated progress, alive_bar by default
self.progress_func = kwargs.get('progress_func', None)
self.progress = None
async def increment_progress(self, count): async def increment_progress(self, count):
if self.bar_update: update_func = self.progress.update
self.bar_update(count)
if asyncio.iscoroutinefunction(update_func):
await update_func(count)
else:
update_func(count)
await asyncio.sleep(0)
async def stop_progress(self):
close_func = self.progress.close
if asyncio.iscoroutinefunction(close_func):
await close_func()
else:
close_func()
await asyncio.sleep(0) await asyncio.sleep(0)
async def worker(self): async def worker(self):
@@ -122,7 +137,10 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
result = kwargs.get('default') result = kwargs.get('default')
self.results.append(result) self.results.append(result)
if self.progress:
await self.increment_progress(1) await self.increment_progress(1)
self.queue.task_done() self.queue.task_done()
async def _run(self, queries: Iterable[QueryDraft]): async def _run(self, queries: Iterable[QueryDraft]):
@@ -132,9 +150,22 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
workers = [create_task_func()(self.worker()) for _ in range(min_workers)] workers = [create_task_func()(self.worker()) for _ in range(min_workers)]
if self.progress_func:
self.progress = self.progress_func(total=len(queries_list))
for t in queries_list:
await self.queue.put(t)
await self.queue.join()
for w in workers:
w.cancel()
await self.stop_progress()
else:
# Initialize alive_progress bar # Initialize alive_progress bar
with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar: with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar:
self.bar_update = bar # `alive_bar` uses its instance to update progress self.update = bar # `alive_bar` uses its instance to update progress
for t in queries_list: for t in queries_list:
await self.queue.put(t) await self.queue.put(t)
+4 -4
View File
@@ -55,12 +55,12 @@ async def test_asyncio_progressbar_queue_executor():
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=2) executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=2)
assert await executor.run(tasks) == [0, 1, 3, 2, 4, 6, 7, 5, 9, 8] assert await executor.run(tasks) == [0, 1, 3, 2, 4, 6, 7, 5, 9, 8]
assert executor.execution_time > 0.5 assert executor.execution_time > 0.5
assert executor.execution_time < 0.6 assert executor.execution_time < 0.7
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=3) executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=3)
assert await executor.run(tasks) == [0, 3, 1, 4, 6, 2, 7, 9, 5, 8] assert await executor.run(tasks) == [0, 3, 1, 4, 6, 2, 7, 9, 5, 8]
assert executor.execution_time > 0.4 assert executor.execution_time > 0.4
assert executor.execution_time < 0.5 assert executor.execution_time < 0.6
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=5) executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=5)
assert await executor.run(tasks) in ( assert await executor.run(tasks) in (
@@ -68,9 +68,9 @@ async def test_asyncio_progressbar_queue_executor():
[0, 3, 6, 1, 4, 9, 7, 2, 5, 8], [0, 3, 6, 1, 4, 9, 7, 2, 5, 8],
) )
assert executor.execution_time > 0.3 assert executor.execution_time > 0.3
assert executor.execution_time < 0.4 assert executor.execution_time < 0.5
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=10) executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=10)
assert await executor.run(tasks) == [0, 3, 6, 9, 1, 4, 7, 2, 5, 8] assert await executor.run(tasks) == [0, 3, 6, 9, 1, 4, 7, 2, 5, 8]
assert executor.execution_time > 0.2 assert executor.execution_time > 0.2
assert executor.execution_time < 0.3 assert executor.execution_time < 0.4