mirror of
https://github.com/soxoj/maigret.git
synced 2026-05-06 14:08:59 +00:00
Maigret bot support (custom progress function fixed) (#1898)
* Fixed progress/close functions * Fixed tests: execution time increased with alive_progressbar
This commit is contained in:
+5
-3
@@ -162,15 +162,16 @@ class AiodnsDomainResolver(CheckerBase):
|
||||
self.resolver = aiodns.DNSResolver(loop=loop)
|
||||
|
||||
def prepare(self, url, headers=None, allow_redirects=True, timeout=0, method='get'):
|
||||
return self.resolver.query(url, 'A')
|
||||
self.url = url
|
||||
return None
|
||||
|
||||
async def check(self, future) -> Tuple[str, int, Optional[CheckError]]:
|
||||
async def check(self) -> Tuple[str, int, Optional[CheckError]]:
|
||||
status = 404
|
||||
error = None
|
||||
text = ''
|
||||
|
||||
try:
|
||||
res = await future
|
||||
res = await self.resolver.query(self.url, 'A')
|
||||
text = str(res[0].host)
|
||||
status = 200
|
||||
except aiodns.error.DNSError:
|
||||
@@ -530,6 +531,7 @@ def make_site_result(
|
||||
|
||||
# Store future request object in the results object
|
||||
results_site["future"] = future
|
||||
|
||||
results_site["checker"] = checker
|
||||
|
||||
return results_site
|
||||
|
||||
+35
-4
@@ -100,11 +100,26 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
||||
self.workers_count = kwargs.get('in_parallel', 10)
|
||||
self.queue = asyncio.Queue(self.workers_count)
|
||||
self.timeout = kwargs.get('timeout')
|
||||
self.bar_update = None # Store the update function from alive_bar
|
||||
# a function to show updated progress, alive_bar by default
|
||||
self.progress_func = kwargs.get('progress_func', None)
|
||||
self.progress = None
|
||||
|
||||
async def increment_progress(self, count):
|
||||
if self.bar_update:
|
||||
self.bar_update(count)
|
||||
update_func = self.progress.update
|
||||
|
||||
if asyncio.iscoroutinefunction(update_func):
|
||||
await update_func(count)
|
||||
else:
|
||||
update_func(count)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def stop_progress(self):
|
||||
close_func = self.progress.close
|
||||
|
||||
if asyncio.iscoroutinefunction(close_func):
|
||||
await close_func()
|
||||
else:
|
||||
close_func()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def worker(self):
|
||||
@@ -122,7 +137,10 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
||||
result = kwargs.get('default')
|
||||
|
||||
self.results.append(result)
|
||||
|
||||
if self.progress:
|
||||
await self.increment_progress(1)
|
||||
|
||||
self.queue.task_done()
|
||||
|
||||
async def _run(self, queries: Iterable[QueryDraft]):
|
||||
@@ -132,9 +150,22 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
||||
|
||||
workers = [create_task_func()(self.worker()) for _ in range(min_workers)]
|
||||
|
||||
if self.progress_func:
|
||||
self.progress = self.progress_func(total=len(queries_list))
|
||||
|
||||
for t in queries_list:
|
||||
await self.queue.put(t)
|
||||
|
||||
await self.queue.join()
|
||||
|
||||
for w in workers:
|
||||
w.cancel()
|
||||
|
||||
await self.stop_progress()
|
||||
else:
|
||||
# Initialize alive_progress bar
|
||||
with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar:
|
||||
self.bar_update = bar # `alive_bar` uses its instance to update progress
|
||||
self.update = bar # `alive_bar` uses its instance to update progress
|
||||
|
||||
for t in queries_list:
|
||||
await self.queue.put(t)
|
||||
|
||||
@@ -55,12 +55,12 @@ async def test_asyncio_progressbar_queue_executor():
|
||||
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=2)
|
||||
assert await executor.run(tasks) == [0, 1, 3, 2, 4, 6, 7, 5, 9, 8]
|
||||
assert executor.execution_time > 0.5
|
||||
assert executor.execution_time < 0.6
|
||||
assert executor.execution_time < 0.7
|
||||
|
||||
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=3)
|
||||
assert await executor.run(tasks) == [0, 3, 1, 4, 6, 2, 7, 9, 5, 8]
|
||||
assert executor.execution_time > 0.4
|
||||
assert executor.execution_time < 0.5
|
||||
assert executor.execution_time < 0.6
|
||||
|
||||
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=5)
|
||||
assert await executor.run(tasks) in (
|
||||
@@ -68,9 +68,9 @@ async def test_asyncio_progressbar_queue_executor():
|
||||
[0, 3, 6, 1, 4, 9, 7, 2, 5, 8],
|
||||
)
|
||||
assert executor.execution_time > 0.3
|
||||
assert executor.execution_time < 0.4
|
||||
assert executor.execution_time < 0.5
|
||||
|
||||
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=10)
|
||||
assert await executor.run(tasks) == [0, 3, 6, 9, 1, 4, 7, 2, 5, 8]
|
||||
assert executor.execution_time > 0.2
|
||||
assert executor.execution_time < 0.3
|
||||
assert executor.execution_time < 0.4
|
||||
|
||||
Reference in New Issue
Block a user