mirror of
https://github.com/soxoj/maigret.git
synced 2026-05-06 14:08:59 +00:00
Maigret bot support (custom progress function fixed) (#1898)
* Fixed progress/close functions * Fixed tests: execution time increased with alive_progressbar
This commit is contained in:
+5
-3
@@ -162,15 +162,16 @@ class AiodnsDomainResolver(CheckerBase):
|
|||||||
self.resolver = aiodns.DNSResolver(loop=loop)
|
self.resolver = aiodns.DNSResolver(loop=loop)
|
||||||
|
|
||||||
def prepare(self, url, headers=None, allow_redirects=True, timeout=0, method='get'):
|
def prepare(self, url, headers=None, allow_redirects=True, timeout=0, method='get'):
|
||||||
return self.resolver.query(url, 'A')
|
self.url = url
|
||||||
|
return None
|
||||||
|
|
||||||
async def check(self, future) -> Tuple[str, int, Optional[CheckError]]:
|
async def check(self) -> Tuple[str, int, Optional[CheckError]]:
|
||||||
status = 404
|
status = 404
|
||||||
error = None
|
error = None
|
||||||
text = ''
|
text = ''
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = await future
|
res = await self.resolver.query(self.url, 'A')
|
||||||
text = str(res[0].host)
|
text = str(res[0].host)
|
||||||
status = 200
|
status = 200
|
||||||
except aiodns.error.DNSError:
|
except aiodns.error.DNSError:
|
||||||
@@ -530,6 +531,7 @@ def make_site_result(
|
|||||||
|
|
||||||
# Store future request object in the results object
|
# Store future request object in the results object
|
||||||
results_site["future"] = future
|
results_site["future"] = future
|
||||||
|
|
||||||
results_site["checker"] = checker
|
results_site["checker"] = checker
|
||||||
|
|
||||||
return results_site
|
return results_site
|
||||||
|
|||||||
+35
-4
@@ -100,11 +100,26 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
|||||||
self.workers_count = kwargs.get('in_parallel', 10)
|
self.workers_count = kwargs.get('in_parallel', 10)
|
||||||
self.queue = asyncio.Queue(self.workers_count)
|
self.queue = asyncio.Queue(self.workers_count)
|
||||||
self.timeout = kwargs.get('timeout')
|
self.timeout = kwargs.get('timeout')
|
||||||
self.bar_update = None # Store the update function from alive_bar
|
# a function to show updated progress, alive_bar by default
|
||||||
|
self.progress_func = kwargs.get('progress_func', None)
|
||||||
|
self.progress = None
|
||||||
|
|
||||||
async def increment_progress(self, count):
|
async def increment_progress(self, count):
|
||||||
if self.bar_update:
|
update_func = self.progress.update
|
||||||
self.bar_update(count)
|
|
||||||
|
if asyncio.iscoroutinefunction(update_func):
|
||||||
|
await update_func(count)
|
||||||
|
else:
|
||||||
|
update_func(count)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
async def stop_progress(self):
|
||||||
|
close_func = self.progress.close
|
||||||
|
|
||||||
|
if asyncio.iscoroutinefunction(close_func):
|
||||||
|
await close_func()
|
||||||
|
else:
|
||||||
|
close_func()
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
async def worker(self):
|
async def worker(self):
|
||||||
@@ -122,7 +137,10 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
|||||||
result = kwargs.get('default')
|
result = kwargs.get('default')
|
||||||
|
|
||||||
self.results.append(result)
|
self.results.append(result)
|
||||||
|
|
||||||
|
if self.progress:
|
||||||
await self.increment_progress(1)
|
await self.increment_progress(1)
|
||||||
|
|
||||||
self.queue.task_done()
|
self.queue.task_done()
|
||||||
|
|
||||||
async def _run(self, queries: Iterable[QueryDraft]):
|
async def _run(self, queries: Iterable[QueryDraft]):
|
||||||
@@ -132,9 +150,22 @@ class AsyncioProgressbarQueueExecutor(AsyncExecutor):
|
|||||||
|
|
||||||
workers = [create_task_func()(self.worker()) for _ in range(min_workers)]
|
workers = [create_task_func()(self.worker()) for _ in range(min_workers)]
|
||||||
|
|
||||||
|
if self.progress_func:
|
||||||
|
self.progress = self.progress_func(total=len(queries_list))
|
||||||
|
|
||||||
|
for t in queries_list:
|
||||||
|
await self.queue.put(t)
|
||||||
|
|
||||||
|
await self.queue.join()
|
||||||
|
|
||||||
|
for w in workers:
|
||||||
|
w.cancel()
|
||||||
|
|
||||||
|
await self.stop_progress()
|
||||||
|
else:
|
||||||
# Initialize alive_progress bar
|
# Initialize alive_progress bar
|
||||||
with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar:
|
with alive_bar(len(queries_list), title="Searching", force_tty=True) as bar:
|
||||||
self.bar_update = bar # `alive_bar` uses its instance to update progress
|
self.update = bar # `alive_bar` uses its instance to update progress
|
||||||
|
|
||||||
for t in queries_list:
|
for t in queries_list:
|
||||||
await self.queue.put(t)
|
await self.queue.put(t)
|
||||||
|
|||||||
@@ -55,12 +55,12 @@ async def test_asyncio_progressbar_queue_executor():
|
|||||||
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=2)
|
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=2)
|
||||||
assert await executor.run(tasks) == [0, 1, 3, 2, 4, 6, 7, 5, 9, 8]
|
assert await executor.run(tasks) == [0, 1, 3, 2, 4, 6, 7, 5, 9, 8]
|
||||||
assert executor.execution_time > 0.5
|
assert executor.execution_time > 0.5
|
||||||
assert executor.execution_time < 0.6
|
assert executor.execution_time < 0.7
|
||||||
|
|
||||||
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=3)
|
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=3)
|
||||||
assert await executor.run(tasks) == [0, 3, 1, 4, 6, 2, 7, 9, 5, 8]
|
assert await executor.run(tasks) == [0, 3, 1, 4, 6, 2, 7, 9, 5, 8]
|
||||||
assert executor.execution_time > 0.4
|
assert executor.execution_time > 0.4
|
||||||
assert executor.execution_time < 0.5
|
assert executor.execution_time < 0.6
|
||||||
|
|
||||||
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=5)
|
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=5)
|
||||||
assert await executor.run(tasks) in (
|
assert await executor.run(tasks) in (
|
||||||
@@ -68,9 +68,9 @@ async def test_asyncio_progressbar_queue_executor():
|
|||||||
[0, 3, 6, 1, 4, 9, 7, 2, 5, 8],
|
[0, 3, 6, 1, 4, 9, 7, 2, 5, 8],
|
||||||
)
|
)
|
||||||
assert executor.execution_time > 0.3
|
assert executor.execution_time > 0.3
|
||||||
assert executor.execution_time < 0.4
|
assert executor.execution_time < 0.5
|
||||||
|
|
||||||
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=10)
|
executor = AsyncioProgressbarQueueExecutor(logger=logger, in_parallel=10)
|
||||||
assert await executor.run(tasks) == [0, 3, 6, 9, 1, 4, 7, 2, 5, 8]
|
assert await executor.run(tasks) == [0, 3, 6, 9, 1, 4, 7, 2, 5, 8]
|
||||||
assert executor.execution_time > 0.2
|
assert executor.execution_time > 0.2
|
||||||
assert executor.execution_time < 0.3
|
assert executor.execution_time < 0.4
|
||||||
|
|||||||
Reference in New Issue
Block a user