diff --git a/maigret/checking.py b/maigret/checking.py index f147a94..4aeab39 100644 --- a/maigret/checking.py +++ b/maigret/checking.py @@ -162,15 +162,16 @@ class AiodnsDomainResolver(CheckerBase): self.resolver = aiodns.DNSResolver(loop=loop) def prepare(self, url, headers=None, allow_redirects=True, timeout=0, method='get'): - return self.resolver.query(url, 'A') + self.url = url + return None - async def check(self, future) -> Tuple[str, int, Optional[CheckError]]: + async def check(self) -> Tuple[str, int, Optional[CheckError]]: status = 404 error = None text = '' try: - res = await future + res = await self.resolver.query(self.url, 'A') text = str(res[0].host) status = 200 except aiodns.error.DNSError: @@ -530,7 +531,8 @@ def make_site_result( # Store future request object in the results object results_site["future"] = future - results_site["checker"] = checker + + results_site["checker"] = checker return results_site diff --git a/maigret/executors.py b/maigret/executors.py index 2ecc32f..7b8e173 100644 --- a/maigret/executors.py +++ b/maigret/executors.py @@ -100,12 +100,27 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor): self.workers_count = kwargs.get('in_parallel', 10) self.queue = asyncio.Queue(self.workers_count) self.timeout = kwargs.get('timeout') - self.bar_update = None # Store the update function from alive_bar + # a function to show updated progress, alive_bar by default + self.progress_func = kwargs.get('progress_func', None) + self.progress = None async def increment_progress(self, count): - if self.bar_update: - self.bar_update(count) - await asyncio.sleep(0) + update_func = self.progress.update + + if asyncio.iscoroutinefunction(update_func): + await update_func(count) + else: + update_func(count) + await asyncio.sleep(0) + + async def stop_progress(self): + close_func = self.progress.close + + if asyncio.iscoroutinefunction(close_func): + await close_func() + else: + close_func() + await asyncio.sleep(0) async def worker(self): while True: @@ -122,7 +137,10 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor): result = kwargs.get('default') self.results.append(result) - await self.increment_progress(1) + + if self.progress: + await self.increment_progress(1) + self.queue.task_done() async def _run(self, queries: Iterable[QueryDraft]): @@ -132,9 +150,8 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor): workers = [create_task_func()(self.worker()) for _ in range(min_workers)] - # Initialize alive_progress bar - with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar: - self.bar_update = bar # `alive_bar` uses its instance to update progress + if self.progress_func: + self.progress = self.progress_func(total=len(queries_list)) for t in queries_list: await self.queue.put(t) @@ -144,4 +161,18 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor): for w in workers: w.cancel() + await self.stop_progress() + else: + # Initialize alive_progress bar + with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar: + self.update = bar # `alive_bar` uses its instance to update progress + + for t in queries_list: + await self.queue.put(t) + + await self.queue.join() + + for w in workers: + w.cancel() + return self.results \ No newline at end of file diff --git a/tests/test_executors.py b/tests/test_executors.py index 1b25df2..4cb4b98 100644 --- a/tests/test_executors.py +++ b/tests/test_executors.py @@ -55,12 +55,12 @@ async def test_asyncio_progressbar_queue_executor(): executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=2) assert await executor.run(tasks) == [0, 1, 3, 2, 4, 6, 7, 5, 9, 8] assert executor.execution_time > 0.5 - assert executor.execution_time < 0.6 + assert executor.execution_time < 0.7 executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=3) assert await executor.run(tasks) == [0, 3, 1, 4, 6, 2, 7, 9, 5, 8] assert executor.execution_time > 0.4 - assert executor.execution_time < 0.5 + assert executor.execution_time < 0.6 executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=5) assert await executor.run(tasks) in ( @@ -68,9 +68,9 @@ async def test_asyncio_progressbar_queue_executor(): [0, 3, 6, 1, 4, 9, 7, 2, 5, 8], ) assert executor.execution_time > 0.3 - assert executor.execution_time < 0.4 + assert executor.execution_time < 0.5 executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=10) assert await executor.run(tasks) == [0, 3, 6, 9, 1, 4, 7, 2, 5, 8] assert executor.execution_time > 0.2 - assert executor.execution_time < 0.3 + assert executor.execution_time < 0.4