Skip to content

Commit

Permalink
Merge pull request #211 from ninoseki/improve-concurrency
Browse files Browse the repository at this point in the history
refactor: improve concurrency
  • Loading branch information
ninoseki authored Mar 17, 2024
2 parents 1c930ed + 2b55e8a commit 997df76
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 79 deletions.
24 changes: 15 additions & 9 deletions backend/services/dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
ares_query_cname_result,
ares_query_txt_result,
)
from returns.future import future_safe
from returns.unsafe import unsafe_perform_io

from backend import schemas
from backend.schemas.dns import AAAA, CNAME, TXT, A
Expand All @@ -20,17 +22,21 @@
QUERY_TYPES = typing.Literal["A", "AAAA", "CNAME", "TXT"]


async def query(name: str, query_type: QUERY_TYPES):
try:
resolver = aiodns.DNSResolver()
records = await resolver.query(name, query_type)
@future_safe
async def safe_query(
name: str, query_type: QUERY_TYPES, resolver: aiodns.DNSResolver | None = None
) -> list[typing.Any]:
resolver = resolver or aiodns.DNSResolver()
records = await resolver.query(name, query_type)

if not isinstance(records, list):
records = [records]
if not isinstance(records, list):
records = [records]

return typing.cast(list[typing.Any], records)
except aiodns.error.DNSError:
return []
return typing.cast(list[typing.Any], records)


async def query(name: str, query_type: QUERY_TYPES) -> list[typing.Any]:
return unsafe_perform_io((await safe_query(name, query_type)).value_or([]))


async def query_a_records(name: str) -> list[A]:
Expand Down
134 changes: 64 additions & 70 deletions backend/services/fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
from dataclasses import dataclass, field
from dataclasses import dataclass

import aiometer
import httpx
from returns.functions import raise_exception
from returns.future import FutureResultE, future_safe
Expand All @@ -24,102 +24,96 @@
class Container:
response: httpx.Response

html: schemas.HTML | None = field(default=None)
dns: schemas.DNS | None = field(default=None)
tracker: schemas.Tracker | None = field(default=None)
whois: schemas.Whois | None = field(default=None)
favicon: schemas.Favicon | None = field(default=None)
certificate: schemas.Certificate | None = field(default=None)
tls: schemas.TLS | None = field(default=None)


@future_safe
async def get_url(url: str) -> Container:
async def get_url(url: str) -> httpx.Response:
async with httpx.AsyncClient(verify=False) as client:
response = await client.get(url, follow_redirects=True)
return Container(response=response)
return await client.get(url, follow_redirects=True)


@future_safe
async def get_whois(container: Container) -> Container:
container.whois = await Whois().call(container.response)
return container
async def get_whois(res: httpx.Response) -> schemas.Whois:
return await Whois().call(res)


@future_safe
async def get_certificate(container: Container) -> Container:
with contextlib.suppress(Exception):
container.certificate = await Certificate().call(container.response)

return container
async def get_certificate(res: httpx.Response) -> schemas.Certificate | None:
return await Certificate().call(res)


@future_safe
async def get_favicon(container: Container) -> Container:
with contextlib.suppress(Exception):
container.favicon = await Favicon().call(container.response)

return container
async def get_favicon(res: httpx.Response) -> schemas.Favicon | None:
return await Favicon().call(res)


@future_safe
async def get_html(container: Container) -> Container:
with contextlib.suppress(Exception):
container.html = await HTML().call(container.response)

return container
async def get_html(res: httpx.Response) -> schemas.HTML:
return await HTML().call(res)


@future_safe
async def get_tracker(container: Container) -> Container:
try:
container.tracker = await Tracker().call(container.response)
except Exception:
container.tracker = schemas.Tracker()

return container
async def get_tracker(res: httpx.Response) -> schemas.Tracker:
return await Tracker().call(res)


@future_safe
async def get_dns(container: Container) -> Container:
try:
container.dns = await DNS().call(container.response)
except Exception:
container.dns = schemas.DNS()

return container
async def get_dns(res: httpx.Response) -> schemas.DNS:
return await DNS().call(res)


@future_safe
async def get_tls(container: Container) -> Container:
with contextlib.suppress(Exception):
container.tls = await TLS().call(container.response)
async def get_tls(res: httpx.Response) -> schemas.TLS | None:
return await TLS().call(res)


return container
@future_safe
async def get_fingerprint(res: httpx.Response) -> schemas.Fingerprint:
(
certificate_result,
dns_result,
favicon_result,
html_result,
tls_result,
tracker_result,
whois_result,
) = await aiometer.run_all(
[
get_certificate(res).awaitable,
get_dns(res).awaitable,
get_favicon(res).awaitable,
get_html(res).awaitable,
get_tls(res).awaitable,
get_tracker(res).awaitable,
get_whois(res).awaitable,
]
)
certificate: schemas.Certificate | None = unsafe_perform_io(
certificate_result.value_or(None) # type: ignore
) # type: ignore
dns: schemas.DNS = unsafe_perform_io(dns_result.value_or(schemas.DNS())) # type: ignore
favicon: schemas.Favicon = unsafe_perform_io(favicon_result.value_or(None)) # type: ignore
html: schemas.HTML = unsafe_perform_io(html_result.value_or(None)) # type: ignore
tls: schemas.TLS = unsafe_perform_io(tls_result.value_or(None)) # type: ignore
tracker: schemas.Tracker = unsafe_perform_io(
tracker_result.value_or(schemas.Tracker()) # type: ignore
)
whois: schemas.Whois = unsafe_perform_io(whois_result.value_or(schemas.Whois())) # type: ignore
return schemas.Fingerprint(
certificate=certificate,
dns=dns,
favicon=favicon,
html=html,
tls=tls,
tracker=tracker,
whois=whois,
headers=dict(res.headers),
)


class Fingerprint(AbstractService):
async def call(self, url: str) -> schemas.Fingerprint:
f_result: FutureResultE[Container] = flow(
url,
get_url,
bind(get_certificate),
bind(get_dns),
bind(get_favicon),
bind(get_html),
bind(get_tls),
bind(get_tracker),
bind(get_whois),
f_result: FutureResultE[schemas.Fingerprint] = flow(
url, get_url, bind(get_fingerprint)
)
result = (await f_result.awaitable()).alt(raise_exception).unwrap()
container = unsafe_perform_io(result)
return schemas.Fingerprint(
html=container.html, # type: ignore
dns=container.dns, # type: ignore
tracker=container.tracker, # type: ignore
whois=container.whois, # type: ignore
favicon=container.favicon,
certificate=container.certificate,
tls=container.tls,
headers=dict(container.response.headers),
)
return unsafe_perform_io(result)

0 comments on commit 997df76

Please sign in to comment.