diff --git a/.black.toml b/.black.toml new file mode 100644 index 0000000..474d0af --- /dev/null +++ b/.black.toml @@ -0,0 +1,3 @@ +[tool.black] +line-length = 120 +target-version = ['py311'] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8f37e23..9001f5d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,43 +6,53 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: '3.x' - - run: | - pip install flake8 - flake8 aioftp tests + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: '3.11' + - name: black + run: | + pip install black + black --config .black.toml --diff --check . + - name: ruff + run: | + pip install ruff + ruff check . + - name: mypy + run: | + pip install mypy + mypy tests: needs: lint runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10'] + python-version: ['3.11'] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - run: | - pip install -e ./[tests] - pytest - - uses: codecov/codecov-action@v2 - if: github.ref == 'refs/heads/master' - with: - fail_ci_if_error: true - verbose: true + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - run: | + pip install -e ./[dev] + pytest + - uses: codecov/codecov-action@v2 + if: github.ref == 'refs/heads/master' + with: + fail_ci_if_error: true + verbose: true + token: ${{ secrets.CODECOV_TOKEN }} deploy: needs: tests runs-on: ubuntu-latest if: github.ref == 'refs/heads/master' steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - - uses: casperdcl/deploy-pypi@v2 - with: - password: ${{ secrets.PYPI_TOKEN }} - build: true - skip_existing: true + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - uses: casperdcl/deploy-pypi@v2 + with: + password: ${{ secrets.PYPI_TOKEN }} + build: true + skip_existing: true diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 0000000..e107134 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,16 @@ +[mypy] +files = aioftp +strict = true +ignore_missing_imports = true +allow_subclassing_any = true +allow_untyped_calls = true +pretty = true +show_error_codes = true +implicit_reexport = true +allow_untyped_decorators = true +warn_unused_ignores = false +warn_return_any = false +namespace_packages = true + +[mypy-tests.*] +ignore_errors = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..23a7a87 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,48 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks + +.python-linters: &python-linters + pass_filenames: false + language: system + types: [python] + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.4.0 + hooks: + - id: check-ast + - id: trailing-whitespace + - id: end-of-file-fixer + +- repo: https://github.com/asottile/add-trailing-comma + rev: v2.1.0 + hooks: + - id: add-trailing-comma + +- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.1.0 + hooks: + - id: pretty-format-yaml + args: + - --autofix + - --preserve-quotes + - --indent=2 + +- repo: local + hooks: + - <<: *python-linters + id: black + name: Format with Black + entry: black + args: ["--config", ".black.toml", "."] + + - <<: *python-linters + id: ruff + name: Check with ruff + entry: ruff + args: ["check", "--fix", "."] + + - <<: *python-linters + id: mypy + name: Validate types with MyPy + entry: mypy diff --git a/pytest.ini b/.pytest.ini similarity index 100% rename from pytest.ini rename to .pytest.ini diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 0000000..b37fd86 --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,4 @@ +line-length = 120 +target-version = "py311" +select = ["E", "W", "F", "Q", "UP", "I", "ASYNC"] +exclude = ["docs"] diff --git a/README.rst b/README.rst index c56e91a..80fff59 100644 --- a/README.rst +++ b/README.rst @@ -100,12 +100,14 @@ pyftpdlib 1.5.2 Dependencies ------------ -- Python 3.7+ +- Python 3.11+ 0.13.0 is the last version which supports python 3.5.3+ 0.16.1 is the last version which supports python 3.6+ +0.21.4 is the last version which supports python 3.7-3.10 + License ------- diff --git a/aioftp/__init__.py b/aioftp/__init__.py index 10a5cb0..2a0c7dd 100644 --- a/aioftp/__init__.py +++ b/aioftp/__init__.py @@ -1,20 +1,19 @@ """ftp client/server for asyncio""" # flake8: noqa -from .client import * -from .common import * -from .errors import * -from .pathio import * -from .server import * +from aioftp.client import * +from aioftp.common import * +from aioftp.errors import * +from aioftp.pathio import * +from aioftp.server import * +from aioftp.client import __all__ as client_all +from aioftp.server import __all__ as server_all +from aioftp.errors import __all__ as errors_all +from aioftp.common import __all__ as common_all +from aioftp.pathio import __all__ as pathio_all + __version__ = "0.21.4" version = tuple(map(int, __version__.split("."))) -__all__ = ( - client.__all__ + - server.__all__ + - errors.__all__ + - common.__all__ + - pathio.__all__ + - ("version", "__version__") -) +__all__ = client_all + server_all + errors_all + common_all + pathio_all + ("version", "__version__") diff --git a/aioftp/__main__.py b/aioftp/__main__.py index 8ee10db..d7ba18c 100644 --- a/aioftp/__main__.py +++ b/aioftp/__main__.py @@ -1,4 +1,3 @@ -"""Simple aioftp-based server with one user (anonymous or not)""" import argparse import asyncio import contextlib @@ -10,24 +9,36 @@ parser = argparse.ArgumentParser( prog="aioftp", usage="%(prog)s [options]", - description="Simple aioftp-based server with one user (anonymous or not)." + description="Simple aioftp-based server with one user (anonymous or not).", +) +parser.add_argument("--user", metavar="LOGIN", dest="login", help="user name to login") +parser.add_argument("--pass", metavar="PASSWORD", dest="password", help="password to login") +parser.add_argument( + "-d", + metavar="DIRECTORY", + dest="home", + help="the directory to share (default current directory)", +) +parser.add_argument( + "-q", + "--quiet", + action="store_true", + help="set logging level to 'ERROR' instead of 'INFO'", ) -parser.add_argument("--user", metavar="LOGIN", dest="login", - help="user name to login") -parser.add_argument("--pass", metavar="PASSWORD", dest="password", - help="password to login") -parser.add_argument("-d", metavar="DIRECTORY", dest="home", - help="the directory to share (default current directory)") -parser.add_argument("-q", "--quiet", action="store_true", - help="set logging level to 'ERROR' instead of 'INFO'") parser.add_argument("--memory", action="store_true", help="use memory storage") -parser.add_argument("--host", default=None, - help="host for binding [default: %(default)s]") -parser.add_argument("--port", type=int, default=2121, - help="port for binding [default: %(default)s]") -parser.add_argument("--family", choices=("ipv4", "ipv6", "auto"), - default="auto", - help="Socket family [default: %(default)s]") +parser.add_argument("--host", default=None, help="host for binding [default: %(default)s]") +parser.add_argument( + "--port", + type=int, + default=2121, + help="port for binding [default: %(default)s]", +) +parser.add_argument( + "--family", + choices=("ipv4", "ipv6", "auto"), + default="auto", + help="Socket family [default: %(default)s]", +) args = parser.parse_args() print(f"aioftp v{aioftp.__version__}") @@ -40,7 +51,7 @@ ) if args.memory: user = aioftp.User(args.login, args.password, base_path="/") - path_io_factory = aioftp.MemoryPathIO + path_io_factory: type[aioftp.MemoryPathIO] | type[aioftp.PathIO] = aioftp.MemoryPathIO else: if args.home: user = aioftp.User(args.login, args.password, base_path=args.home) diff --git a/aioftp/client.py b/aioftp/client.py index 0c3ec85..4414a4f 100644 --- a/aioftp/client.py +++ b/aioftp/client.py @@ -6,10 +6,14 @@ import logging import pathlib import re +from collections.abc import AsyncIterable, Awaitable, Callable from functools import partial +from ssl import SSLContext +from types import TracebackType +from typing import Any, Literal, Self, TypeVar -from . import errors, pathio -from .common import ( +from aioftp import errors, pathio +from aioftp.common import ( DEFAULT_ACCOUNT, DEFAULT_BLOCK_SIZE, DEFAULT_PASSWORD, @@ -27,7 +31,7 @@ ) try: - from siosocks.io.asyncio import open_connection + from siosocks.io.asyncio import open_connection # type: ignore except ImportError: from asyncio import open_connection @@ -35,8 +39,8 @@ __all__ = ( "BaseClient", "Client", - "DataConnectionThrottleStreamIO", "Code", + "DataConnectionThrottleStreamIO", ) logger = logging.getLogger(__name__) @@ -45,7 +49,8 @@ class Code(str): """ Representation of server status code. """ - def matches(self, mask): + + def matches(self, mask: str) -> bool: """ :param mask: Template for comparision. If mask symbol is not digit then it passes. @@ -53,92 +58,50 @@ def matches(self, mask): :: - >>> Code("123").matches("1") + >>> Code('123').matches('1') True - >>> Code("123").matches("1x3") + >>> Code('123').matches('1x3') True """ return all(map(lambda m, c: not m.isdigit() or m == c, mask, self)) -class DataConnectionThrottleStreamIO(ThrottleStreamIO): - """ - Add `finish` method to :py:class:`aioftp.ThrottleStreamIO`, which is - specific for data connection. This requires `client`. - - :param client: client class, which have :py:meth:`aioftp.Client.command` - :type client: :py:class:`aioftp.BaseClient` - - :param *args: positional arguments passed to - :py:class:`aioftp.ThrottleStreamIO` - - :param **kwargs: keyword arguments passed to - :py:class:`aioftp.ThrottleStreamIO` - """ - def __init__(self, client, *args, **kwargs): - super().__init__(*args, **kwargs) - self.client = client - - async def finish(self, expected_codes="2xx", wait_codes="1xx"): - """ - :py:func:`asyncio.coroutine` - - Close connection and wait for `expected_codes` response from server - passing `wait_codes`. - - :param expected_codes: tuple of expected codes or expected code - :type expected_codes: :py:class:`tuple` of :py:class:`str` or - :py:class:`str` - - :param wait_codes: tuple of wait codes or wait code - :type wait_codes: :py:class:`tuple` of :py:class:`str` or - :py:class:`str` - """ - self.close() - await self.client.command(None, expected_codes, wait_codes) - - async def __aexit__(self, exc_type, exc, tb): - if exc is None: - await self.finish() - else: - self.close() - - class BaseClient: - - def __init__(self, *, - socket_timeout=None, - connection_timeout=None, - read_speed_limit=None, - write_speed_limit=None, - path_timeout=None, - path_io_factory=pathio.PathIO, - encoding="utf-8", - ssl=None, - parse_list_line_custom=None, - parse_list_line_custom_first=True, - passive_commands=("epsv", "pasv"), - **siosocks_asyncio_kwargs): - self.socket_timeout = socket_timeout - self.connection_timeout = connection_timeout - self.throttle = StreamThrottle.from_limits( + def __init__( + self: "BaseClient", + *, + socket_timeout: float | int | None = None, + connection_timeout: float | int | None = None, + read_speed_limit: int | None = None, + write_speed_limit: int | None = None, + path_timeout: float | int | None = None, + path_io_factory: type[pathio.PathIO] = pathio.PathIO, + encoding: str = "utf-8", + ssl: SSLContext | bool | None = None, + parse_list_line_custom: Callable[[bytes], tuple[pathlib.PurePosixPath, dict[str, str]]] | None = None, + parse_list_line_custom_first: bool = True, + passive_commands: tuple[str, ...] = ("epsv", "pasv"), + **siosocks_asyncio_kwargs: dict[Any, Any], + ): + self.socket_timeout: float | int | None = socket_timeout + self.connection_timeout: float | int | None = connection_timeout + self.throttle: StreamThrottle = StreamThrottle.from_limits( read_speed_limit, write_speed_limit, ) self.path_timeout = path_timeout - self.path_io = path_io_factory(timeout=path_timeout) - self.encoding = encoding - self.stream = None - self.ssl = ssl + self.path_io: pathio.AbstractPathIO = path_io_factory(timeout=path_timeout) + self.encoding: str = encoding + self.stream: ThrottleStreamIO | None = None + self.ssl: SSLContext | bool | None = ssl self.parse_list_line_custom = parse_list_line_custom - self.parse_list_line_custom_first = parse_list_line_custom_first - self._passive_commands = passive_commands - self._open_connection = partial(open_connection, ssl=self.ssl, - **siosocks_asyncio_kwargs) - - async def connect(self, host, port=DEFAULT_PORT): - self.server_host = host - self.server_port = port + self.parse_list_line_custom_first: bool = parse_list_line_custom_first + self._passive_commands: tuple[str, ...] = passive_commands + self._open_connection = partial(open_connection, ssl=self.ssl, **siosocks_asyncio_kwargs) + + async def connect(self, host: str, port: int = DEFAULT_PORT) -> None: + self.server_host: str = host + self.server_port: int = port reader, writer = await asyncio.wait_for( self._open_connection(host, port), self.connection_timeout, @@ -150,14 +113,14 @@ async def connect(self, host, port=DEFAULT_PORT): timeout=self.socket_timeout, ) - def close(self): + def close(self) -> None: """ Close connection. """ if self.stream is not None: self.stream.close() - async def parse_line(self): + async def parse_line(self) -> tuple[Code, str]: """ :py:func:`asyncio.coroutine` @@ -171,7 +134,8 @@ async def parse_line(self): :raises asyncio.TimeoutError: if there where no data for `timeout` period """ - line = await self.stream.readline() + assert self.stream is not None + line: bytes = await self.stream.readline() if not line: self.stream.close() raise ConnectionResetError @@ -179,7 +143,7 @@ async def parse_line(self): logger.debug(s) return Code(s[:3]), s[3:] - async def parse_response(self): + async def parse_response(self) -> tuple[Code, list[str]]: """ :py:func:`asyncio.coroutine` @@ -204,7 +168,12 @@ async def parse_response(self): info.append(curr_code + rest) return code, info - def check_codes(self, expected_codes, received_code, info): + def check_codes( + self, + expected_codes: tuple[str, ...], + received_code: Code, + info: list[str], + ) -> None: """ Checks if any of expected matches received. @@ -223,11 +192,13 @@ def check_codes(self, expected_codes, received_code, info): if not any(map(received_code.matches, expected_codes)): raise errors.StatusCodeError(expected_codes, received_code, info) - async def command(self, - command=None, - expected_codes=(), - wait_codes=(), - censor_after=None): + async def command( + self, + command: str | None = None, + expected_code_or_codes: tuple[str, ...] | str = (), + wait_code_or_codes: tuple[str, ...] | str = (), + censor_after: int | None = None, + ) -> tuple[Code, list[str]]: """ :py:func:`asyncio.coroutine` @@ -238,7 +209,7 @@ async def command(self, 3. Check code for expected. :param command: command line - :type command: :py:class:`str` + :type command: :py:class:`str` or :py:class:`None` :param expected_codes: tuple of expected codes or expected code :type expected_codes: :py:class:`tuple` of :py:class:`str` or @@ -252,8 +223,9 @@ async def command(self, when logging :type censor_after: :py:class:`None` or :py:class:`int` """ - expected_codes = wrap_with_container(expected_codes) - wait_codes = wrap_with_container(wait_codes) + assert self.stream is not None + expected_codes = wrap_with_container(expected_code_or_codes) + wait_codes = wrap_with_container(wait_code_or_codes) if command: if censor_after: # Censor the user's command @@ -271,9 +243,10 @@ async def command(self, if expected_codes: self.check_codes(expected_codes, code, info) return code, info + return Code(""), [] @staticmethod - def parse_epsv_response(s): + def parse_epsv_response(s: str) -> tuple[None, int]: """ Parsing `EPSV` (`message (|||port|)`) response. @@ -289,7 +262,7 @@ def parse_epsv_response(s): return None, port @staticmethod - def parse_pasv_response(s): + def parse_pasv_response(s: str) -> tuple[str, int]: """ Parsing `PASV` server response. @@ -306,7 +279,7 @@ def parse_pasv_response(s): return ip, port @staticmethod - def parse_directory_response(s): + def parse_directory_response(s: str) -> pathlib.PurePosixPath: """ Parsing directory server response. @@ -320,10 +293,10 @@ def parse_directory_response(s): directory = "" for ch in s: if not start: - if ch == "\"": + if ch == '"': start = True else: - if ch == "\"": + if ch == '"': seq_quotes += 1 else: if seq_quotes == 1: @@ -335,7 +308,7 @@ def parse_directory_response(s): return pathlib.PurePosixPath(directory) @staticmethod - def parse_unix_mode(s): + def parse_unix_mode(s: str) -> int: """ Parsing unix mode strings ("rwxr-x--t") into hexacimal notation. @@ -374,7 +347,7 @@ def parse_unix_mode(s): return mode @staticmethod - def format_date_time(d): + def format_date_time(d: datetime.datetime) -> str: """ Formats dates from strptime in a consistent format @@ -386,7 +359,7 @@ def format_date_time(d): return d.strftime("%Y%m%d%H%M00") @classmethod - def parse_ls_date(cls, s, *, now=None): + def parse_ls_date(cls, s: str, *, now: datetime.datetime | None = None) -> str: """ Parsing dates from the ls unix utility. For example, "Nov 18 1958", "Jan 03 2018", and "Nov 18 12:29". @@ -400,14 +373,12 @@ def parse_ls_date(cls, s, *, now=None): try: if now is None: now = datetime.datetime.now() - if s.startswith('Feb 29'): + if s.startswith("Feb 29"): # Need to find the nearest previous leap year prev_leap_year = now.year while not calendar.isleap(prev_leap_year): prev_leap_year -= 1 - d = datetime.datetime.strptime( - f"{prev_leap_year} {s}", "%Y %b %d %H:%M" - ) + d = datetime.datetime.strptime(f"{prev_leap_year} {s}", "%Y %b %d %H:%M") # Check if it's next leap year diff = (now - d).total_seconds() if diff > TWO_YEARS_IN_SECONDS: @@ -424,18 +395,18 @@ def parse_ls_date(cls, s, *, now=None): d = datetime.datetime.strptime(s, "%b %d %Y") return cls.format_date_time(d) - def parse_list_line_unix(self, b): + def parse_list_line_unix(self, b: bytes) -> tuple[pathlib.PurePosixPath, dict[str, str]]: """ Attempt to parse a LIST line (similar to unix ls utility). :param b: response line - :type b: :py:class:`bytes` or :py:class:`str` + :type b: :py:class:`bytes` :return: (path, info) :rtype: (:py:class:`pathlib.PurePosixPath`, :py:class:`dict`) """ s = b.decode(encoding=self.encoding).rstrip() - info = {} + info: dict[str, Any] = {} if s[0] == "-": info["type"] = "file" elif s[0] == "d": @@ -472,28 +443,28 @@ def parse_list_line_unix(self, b): s = s[12:].strip() if info["type"] == "link": i = s.rindex(" -> ") - link_dst = s[i + 4:] + link_dst = s[i + 4 :] link_src = s[:i] - i = -2 if link_dst[-1] == "\'" or link_dst[-1] == "\"" else -1 + i = -2 if link_dst[-1] == "'" or link_dst[-1] == '"' else -1 info["type"] = "dir" if link_dst[i] == "/" else "file" s = link_src return pathlib.PurePosixPath(s), info - def parse_list_line_windows(self, b): + def parse_list_line_windows(self, b: bytes) -> tuple[pathlib.PurePosixPath, dict[str, str]]: """ Parsing Microsoft Windows `dir` output :param b: response line - :type b: :py:class:`bytes` or :py:class:`str` + :type b: :py:class:`bytes` :return: (path, info) :rtype: (:py:class:`pathlib.PurePosixPath`, :py:class:`dict`) """ line = b.decode(encoding=self.encoding).rstrip("\r\n") - date_time_end = line.index("M") - date_time_str = line[:date_time_end + 1].strip().split(" ") - date_time_str = " ".join([x for x in date_time_str if len(x) > 0]) - line = line[date_time_end + 1:].lstrip() + date_time_end: int = line.index("M") + date_time_str_ = line[: date_time_end + 1].strip().split(" ") + date_time_str = " ".join([x for x in date_time_str_ if len(x) > 0]) + line = line[date_time_end + 1 :].lstrip() with setlocale("C"): strptime = datetime.datetime.strptime date_time = strptime(date_time_str, "%m/%d/%Y %I:%M %p") @@ -515,19 +486,19 @@ def parse_list_line_windows(self, b): raise ValueError return pathlib.PurePosixPath(filename), info - def parse_list_line(self, b): + def parse_list_line(self, b: bytes) -> tuple[pathlib.PurePosixPath, dict[str, str]]: """ Parse LIST response with both Microsoft Windows® parser and UNIX parser :param b: response line - :type b: :py:class:`bytes` or :py:class:`str` + :type b: :py:class:`bytes` :return: (path, info) :rtype: (:py:class:`pathlib.PurePosixPath`, :py:class:`dict`) """ ex = [] - parsers = [ + parsers: list[Callable[[bytes], tuple[pathlib.PurePosixPath, dict[str, str]]] | None] = [ self.parse_list_line_unix, self.parse_list_line_windows, ] @@ -544,7 +515,7 @@ def parse_list_line(self, b): ex.append(e) raise ValueError("All parsers failed to parse", b, ex) - def parse_mlsx_line(self, b): + def parse_mlsx_line(self, b: bytes | str) -> tuple[pathlib.PurePosixPath, dict[str, str]]: """ Parsing MLS(T|D) response. @@ -558,7 +529,7 @@ def parse_mlsx_line(self, b): s = b.decode(encoding=self.encoding) else: s = b - line = s.rstrip() + line: str = s.rstrip() facts_found, _, name = line.partition(" ") entry = {} for fact in facts_found[:-1].split(";"): @@ -567,6 +538,62 @@ def parse_mlsx_line(self, b): return pathlib.PurePosixPath(name), entry +class DataConnectionThrottleStreamIO(ThrottleStreamIO): + """ + Add `finish` method to :py:class:`aioftp.ThrottleStreamIO`, which is + specific for data connection. This requires `client`. + + :param client: client class, which have :py:meth:`aioftp.Client.command` + :type client: :py:class:`aioftp.BaseClient` + + :param *args: positional arguments passed to + :py:class:`aioftp.ThrottleStreamIO` + + :param **kwargs: keyword arguments passed to + :py:class:`aioftp.ThrottleStreamIO` + """ + + def __init__(self, client: BaseClient, *args: list[Any], **kwargs: dict[Any, Any]): + super().__init__(*args, **kwargs) + self.client: BaseClient = client + + async def finish( + self, + expected_codes: tuple[str, ...] | str = "2xx", + wait_codes: tuple[str, ...] | str = "1xx", + ) -> None: + """ + :py:func:`asyncio.coroutine` + + Close connection and wait for `expected_codes` response from server + passing `wait_codes`. + + :param expected_codes: tuple of expected codes or expected code + :type expected_codes: :py:class:`tuple` of :py:class:`str` or + :py:class:`str` + + :param wait_codes: tuple of wait codes or wait code + :type wait_codes: :py:class:`tuple` of :py:class:`str` or + :py:class:`str` + """ + self.close() + await self.client.command(None, expected_codes, wait_codes) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if exc_val is None: + await self.finish() + else: + self.close() + + +T = TypeVar("T", bound="Client") + + class Client(BaseClient): """ FTP client. @@ -612,7 +639,8 @@ class Client(BaseClient): :type parse_list_line_custom_first: :py:class:`bool` :param **siosocks_asyncio_kwargs: siosocks key-word only arguments """ - async def connect(self, host, port=DEFAULT_PORT): + + async def connect(self, host: str, port: int = DEFAULT_PORT) -> list[str]: """ :py:func:`asyncio.coroutine` @@ -628,8 +656,12 @@ async def connect(self, host, port=DEFAULT_PORT): code, info = await self.command(None, "220", "120") return info - async def login(self, user=DEFAULT_USER, password=DEFAULT_PASSWORD, - account=DEFAULT_ACCOUNT): + async def login( + self, + user: str = DEFAULT_USER, + password: str = DEFAULT_PASSWORD, + account: str = DEFAULT_ACCOUNT, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -656,10 +688,9 @@ async def login(self, user=DEFAULT_USER, password=DEFAULT_PASSWORD, cmd = "ACCT " + account else: raise errors.StatusCodeError("33x", code, info) - code, info = await self.command(cmd, ("230", "33x"), - censor_after=censor_after) + code, info = await self.command(cmd, ("230", "33x"), censor_after=censor_after) - async def get_current_directory(self): + async def get_current_directory(self) -> pathlib.PurePosixPath: """ :py:func:`asyncio.coroutine` @@ -671,7 +702,7 @@ async def get_current_directory(self): directory = self.parse_directory_response(info[-1]) return directory - async def change_directory(self, path=".."): + async def change_directory(self, path: str | pathlib.PurePosixPath = "..") -> None: """ :py:func:`asyncio.coroutine` @@ -687,7 +718,7 @@ async def change_directory(self, path=".."): cmd = "CWD " + str(path) await self.command(cmd, "2xx") - async def make_directory(self, path, *, parents=True): + async def make_directory(self, path: str | pathlib.PurePosixPath, *, parents: bool = True) -> None: """ :py:func:`asyncio.coroutine` @@ -710,7 +741,7 @@ async def make_directory(self, path, *, parents=True): for path in need_create: await self.command("MKD " + str(path), "257") - async def remove_directory(self, path): + async def remove_directory(self, path: str | pathlib.PurePosixPath) -> None: """ :py:func:`asyncio.coroutine` @@ -721,7 +752,13 @@ async def remove_directory(self, path): """ await self.command("RMD " + str(path), "250") - def list(self, path="", *, recursive=False, raw_command=None): + def list( + self, + path: str | pathlib.PurePosixPath = "", + *, + recursive: bool = False, + raw_command: str | None = None, + ) -> Awaitable[list[tuple[pathlib.PurePosixPath, dict[str, str]]]]: """ :py:func:`asyncio.coroutine` @@ -755,15 +792,16 @@ def list(self, path="", *, recursive=False, raw_command=None): >>> stats = await client.list() """ + class AsyncLister(AsyncListerMixin): - stream = None + stream: DataConnectionThrottleStreamIO | None = None - async def _new_stream(cls, local_path): + async def _new_stream(cls, local_path: str | pathlib.PurePosixPath) -> DataConnectionThrottleStreamIO: cls.path = local_path cls.parse_line = self.parse_mlsx_line if raw_command not in [None, "MLSD", "LIST"]: - raise ValueError("raw_command must be one of MLSD or " - f"LIST, but got {raw_command}") + raise ValueError("raw_command must be one of MLSD or " f"LIST, but got {raw_command}") + # assert raw_command in [None, "MLSD", "LIST"] if raw_command in [None, "MLSD"]: try: command = ("MLSD " + str(cls.path)).strip() @@ -772,20 +810,19 @@ async def _new_stream(cls, local_path): code = e.received_codes[-1] if not code.matches("50x") or raw_command is not None: raise - if raw_command in [None, "LIST"]: - cls.parse_line = self.parse_list_line - command = ("LIST " + str(cls.path)).strip() - return await self.get_stream(command, "1xx") + cls.parse_line = self.parse_list_line + command = ("LIST " + str(cls.path)).strip() + return await self.get_stream(command, "1xx") - def __aiter__(cls): - cls.directories = collections.deque() + def __aiter__(cls) -> Self: + cls.directories: collections.deque[tuple[pathlib.PurePosixPath, dict[str, str]]] = collections.deque() return cls - async def __anext__(cls): + async def __anext__(cls) -> tuple[pathlib.PurePosixPath, dict[str, str]]: if cls.stream is None: cls.stream = await cls._new_stream(path) while True: - line = await cls.stream.readline() + line: bytes = await cls.stream.readline() while not line: await cls.stream.finish() if cls.directories: @@ -806,7 +843,7 @@ async def __anext__(cls): return AsyncLister() - async def stat(self, path): + async def stat(self, path: str | pathlib.PurePosixPath) -> dict: """ :py:func:`asyncio.coroutine` @@ -821,23 +858,33 @@ async def stat(self, path): path = pathlib.PurePosixPath(path) try: code, info = await self.command("MLST " + str(path), "2xx") - name, info = self.parse_mlsx_line(info[1].lstrip()) - return info + name, info_ = self.parse_mlsx_line(info[1].lstrip()) + return info_ except errors.StatusCodeError as e: if not e.received_codes[-1].matches("50x"): raise - for p, info in await self.list(path.parent): + for p, info_ in await self.list(path.parent): if p.name == path.name: - return info + return info_ else: raise errors.StatusCodeError( Code("2xx"), Code("550"), - "path does not exists", + info=["path does not exists"], ) - async def is_file(self, path): + async def size(self, path: str | pathlib.PurePosixPath) -> int | None: + path = pathlib.PurePosixPath(path) + try: + code, info = await self.command("SIZE " + str(path), "2xx") + return int(info[0]) + except errors.StatusCodeError as e: + if not e.received_codes[-1].matches("50x"): + raise + return None + + async def is_file(self, path: str | pathlib.PurePosixPath) -> bool: """ :py:func:`asyncio.coroutine` @@ -851,7 +898,7 @@ async def is_file(self, path): info = await self.stat(path) return info["type"] == "file" - async def is_dir(self, path): + async def is_dir(self, path: str | pathlib.PurePosixPath) -> bool: """ :py:func:`asyncio.coroutine` @@ -865,7 +912,7 @@ async def is_dir(self, path): info = await self.stat(path) return info["type"] == "dir" - async def exists(self, path): + async def exists(self, path: str | pathlib.PurePosixPath) -> bool: """ :py:func:`asyncio.coroutine` @@ -884,7 +931,11 @@ async def exists(self, path): return False raise - async def rename(self, source, destination): + async def rename( + self, + source: str | pathlib.PurePosixPath, + destination: str | pathlib.PurePosixPath, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -899,7 +950,7 @@ async def rename(self, source, destination): await self.command("RNFR " + str(source), "350") await self.command("RNTO " + str(destination), "2xx") - async def remove_file(self, path): + async def remove_file(self, path: str | pathlib.PurePosixPath) -> None: """ :py:func:`asyncio.coroutine` @@ -910,7 +961,7 @@ async def remove_file(self, path): """ await self.command("DELE " + str(path), "2xx") - async def remove(self, path): + async def remove(self, path: str | pathlib.PurePosixPath) -> None: """ :py:func:`asyncio.coroutine` @@ -925,12 +976,14 @@ async def remove(self, path): if info["type"] == "file": await self.remove_file(path) elif info["type"] == "dir": - for name, info in (await self.list(path)): + for name, info in await self.list(path): if info["type"] in ("dir", "file"): await self.remove(name) await self.remove_directory(path) - def upload_stream(self, destination, *, offset=0): + def upload_stream( + self, destination: str | pathlib.PurePosixPath, *, offset: int = 0 + ) -> DataConnectionThrottleStreamIO: """ Create stream for write data to `destination` file. @@ -948,7 +1001,9 @@ def upload_stream(self, destination, *, offset=0): offset=offset, ) - def append_stream(self, destination, *, offset=0): + def append_stream( + self, destination: str | pathlib.PurePosixPath, *, offset: int = 0 + ) -> DataConnectionThrottleStreamIO: """ Create stream for append (write) data to `destination` file. @@ -966,8 +1021,14 @@ def append_stream(self, destination, *, offset=0): offset=offset, ) - async def upload(self, source, destination="", *, write_into=False, - block_size=DEFAULT_BLOCK_SIZE): + async def upload( + self, + source: str | pathlib.Path, + destination: str | pathlib.PurePosixPath = "", + *, + write_into: bool = False, + block_size: int = DEFAULT_BLOCK_SIZE, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -994,8 +1055,7 @@ async def upload(self, source, destination="", *, write_into=False, destination = destination / source.name if await self.path_io.is_file(source): await self.make_directory(destination.parent) - async with self.path_io.open(source, mode="rb") as file_in, \ - self.upload_stream(destination) as stream: + async with self.path_io.open(source, mode="rb") as file_in, self.upload_stream(destination) as stream: async for block in file_in.iter_by_block(block_size): await stream.write(block) elif await self.path_io.is_dir(source): @@ -1016,10 +1076,12 @@ async def upload(self, source, destination="", *, write_into=False, path, relative, write_into=True, - block_size=block_size + block_size=block_size, ) - def download_stream(self, source, *, offset=0): + def download_stream( + self, source: str | pathlib.PurePosixPath, *, offset: int = 0 + ) -> DataConnectionThrottleStreamIO: """ :py:func:`asyncio.coroutine` @@ -1035,8 +1097,14 @@ def download_stream(self, source, *, offset=0): """ return self.get_stream("RETR " + str(source), "1xx", offset=offset) - async def download(self, source, destination="", *, write_into=False, - block_size=DEFAULT_BLOCK_SIZE): + async def download( + self, + source: str | pathlib.PurePosixPath, + destination: str | pathlib.Path = "", + *, + write_into: bool = False, + block_size: int = DEFAULT_BLOCK_SIZE, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -1062,21 +1130,18 @@ async def download(self, source, destination="", *, write_into=False, if not write_into: destination = destination / source.name if await self.is_file(source): - await self.path_io.mkdir(destination.parent, - parents=True, exist_ok=True) - async with self.path_io.open(destination, mode="wb") as file_out, \ - self.download_stream(source) as stream: + await self.path_io.mkdir(destination.parent, parents=True, exist_ok=True) + async with self.path_io.open(destination, mode="wb") as file_out, self.download_stream(source) as stream: async for block in stream.iter_by_block(block_size): await file_out.write(block) elif await self.is_dir(source): await self.path_io.mkdir(destination, parents=True, exist_ok=True) - for name, info in (await self.list(source)): + for name, info in await self.list(source): full = destination / name.relative_to(source) if info["type"] in ("file", "dir"): - await self.download(name, full, write_into=True, - block_size=block_size) + await self.download(name, full, write_into=True, block_size=block_size) - async def quit(self): + async def quit(self) -> None: """ :py:func:`asyncio.coroutine` @@ -1085,18 +1150,21 @@ async def quit(self): await self.command("QUIT", "2xx") self.close() - async def _do_epsv(self): + async def _do_epsv(self) -> tuple[None, int]: code, info = await self.command("EPSV", "229") ip, port = self.parse_epsv_response(info[-1]) return ip, port - async def _do_pasv(self): + async def _do_pasv(self) -> tuple[str, int]: code, info = await self.command("PASV", "227") ip, port = self.parse_pasv_response(info[-1]) return ip, port - async def get_passive_connection(self, conn_type="I", - commands=None): + async def get_passive_connection( + self, + conn_type: Literal["I", "A", "E", "L"] = "I", + commands: tuple[str, ...] | None = None, + ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: """ :py:func:`asyncio.coroutine` @@ -1139,7 +1207,12 @@ async def get_passive_connection(self, conn_type="I", return reader, writer @async_enterable - async def get_stream(self, *command_args, conn_type="I", offset=0): + async def get_stream( + self, + *command_args, + conn_type: Literal["I", "A", "E", "L"] = "I", + offset: int = 0, + ) -> DataConnectionThrottleStreamIO: """ :py:func:`asyncio.coroutine` @@ -1169,7 +1242,7 @@ async def get_stream(self, *command_args, conn_type="I", offset=0): ) return stream - async def abort(self, *, wait=True): + async def abort(self, *, wait: bool = True) -> None: """ :py:func:`asyncio.coroutine` @@ -1185,9 +1258,15 @@ async def abort(self, *, wait=True): @classmethod @contextlib.asynccontextmanager - async def context(cls, host, port=DEFAULT_PORT, user=DEFAULT_USER, - password=DEFAULT_PASSWORD, account=DEFAULT_ACCOUNT, - **kwargs): + async def context( + cls: type[T], + host: str, + port: int = DEFAULT_PORT, + user: str = DEFAULT_USER, + password: str = DEFAULT_PASSWORD, + account: str = DEFAULT_ACCOUNT, + **kwargs: dict[Any, Any], + ) -> AsyncIterable[T]: """ Classmethod async context manager. This create :py:class:`aioftp.Client`, make async call to diff --git a/aioftp/common.py b/aioftp/common.py index 024d5db..5facb95 100644 --- a/aioftp/common.py +++ b/aioftp/common.py @@ -1,34 +1,42 @@ +from __future__ import annotations + import abc import asyncio import collections import functools import locale import threading +from collections.abc import Awaitable, Callable, Iterator from contextlib import contextmanager +from types import TracebackType +from typing import Any, Self __all__ = ( - "with_timeout", - "StreamIO", - "Throttle", - "StreamThrottle", - "ThrottleStreamIO", - "END_OF_LINE", - "DEFAULT_BLOCK_SIZE", - "wrap_with_container", - "AsyncStreamIterator", "AbstractAsyncLister", - "AsyncListerMixin", "async_enterable", + "AsyncListerMixin", + "AsyncStreamIterator", + "Connection", + "DEFAULT_ACCOUNT", + "DEFAULT_BLOCK_SIZE", + "DEFAULT_PASSWORD", "DEFAULT_PORT", "DEFAULT_USER", - "DEFAULT_PASSWORD", - "DEFAULT_ACCOUNT", + "END_OF_LINE", "setlocale", + "StreamIO", + "StreamThrottle", + "Throttle", + "ThrottleStreamIO", + "with_timeout", + "wrap_with_container", ) END_OF_LINE = "\r\n" -DEFAULT_BLOCK_SIZE = 8192 +DEFAULT_BLOCK_SIZE: int = 8192 +DEFAULT_MAXIMUM_CONNECTIONS: int = 512 +DEFAULT_MAXIMUM_CONNECTIONS_PER_USER: int = 10 DEFAULT_PORT = 21 DEFAULT_USER = "anonymous" @@ -38,22 +46,91 @@ TWO_YEARS_IN_SECONDS = ((365 * 3 + 366) * 24 * 60 * 60) / 2 -def _now(): +class Connection(collections.defaultdict): + """ + Connection state container for transparent work with futures for async + wait + + :param kwargs: initialization parameters + + Container based on :py:class:`collections.defaultdict`, which holds + :py:class:`asyncio.Future` as default factory. There is two layers of + abstraction: + + * Low level based on simple dictionary keys to attributes mapping and + available at Connection.future. + * High level based on futures result and dictionary keys to attributes + mapping and available at Connection. + + To clarify, here is groups of equal expressions + :: + + >>> connection.future.foo + >>> connection["foo"] + + >>> connection.foo + >>> connection["foo"].result() + + >>> del connection.future.foo + >>> del connection.foo + >>> del connection["foo"] + """ + + __slots__ = ("future",) + + class Container: + def __init__(self, storage: dict[str, Any]) -> None: + self.storage = storage + + def __getattr__(self, name: str) -> Any: + return self.storage[name] + + def __delattr__(self, name: str) -> None: + self.storage.pop(name) + + def __init__(self, **kwargs: dict[str, Any]) -> None: + super().__init__(asyncio.Future) + self.future = Connection.Container(self) + for k, v in kwargs.items(): + self[k].set_result(v) + + def __getattr__(self, name: str) -> Any: + if name in self: + return self[name].result() + else: + raise AttributeError(f"{name!r} not in storage") + + def __setattr__(self, name: str, value: Any) -> None: + if name in Connection.__slots__: + super().__setattr__(name, value) + else: + if self[name].done(): + self[name] = super().default_factory() + self[name].set_result(value) + + def __delattr__(self, name: str) -> None: + if name in self: + self.pop(name) + + +def _now() -> float: return asyncio.get_running_loop().time() -def _with_timeout(name): +def _with_timeout(name: str): def decorator(f): @functools.wraps(f) def wrapper(cls, *args, **kwargs): coro = f(cls, *args, **kwargs) timeout = getattr(cls, name) return asyncio.wait_for(coro, timeout) + return wrapper + return decorator -def with_timeout(name): +def with_timeout(name: str): """ Method decorator, wraps method with :py:func:`asyncio.wait_for`. `timeout` argument takes from `name` decorator argument or "timeout". @@ -96,14 +173,13 @@ def with_timeout(name): class AsyncStreamIterator: - - def __init__(self, read_coro): + def __init__(self, read_coro: Callable[[], Awaitable[bytes]]): self.read_coro = read_coro - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> bytes: data = await self.read_coro() if data: return data @@ -111,7 +187,7 @@ async def __anext__(self): raise StopAsyncIteration -class AsyncListerMixin: +class AsyncListerMixin(abc.ABC): """ Add ability to `async for` context to collect data to list via await. @@ -121,7 +197,8 @@ class AsyncListerMixin: ... ... >>> results = await Context(...) """ - async def _to_list(self): + + async def _to_list(self) -> Any: items = [] async for item in self: items.append(item) @@ -130,6 +207,10 @@ async def _to_list(self): def __await__(self): return self._to_list().__await__() + @abc.abstractmethod + def __aiter__(self): + pass + class AbstractAsyncLister(AsyncListerMixin, abc.ABC): """ @@ -159,6 +240,7 @@ class AbstractAsyncLister(AsyncListerMixin, abc.ABC): >>> result [block, block, block, ...] """ + def __init__(self, *, timeout=None): super().__init__() self.timeout = timeout @@ -170,7 +252,7 @@ def __aiter__(self): @abc.abstractmethod async def __anext__(self): """ - :py:func:`asyncio.coroutine` + : py: func: `asyncio.coroutine` Abstract method """ @@ -210,11 +292,10 @@ def async_enterable(f): ... # do """ + @functools.wraps(f) def wrapper(*args, **kwargs): - class AsyncEnterableInstance: - async def __aenter__(self): self.context = await f(*args, **kwargs) return await self.context.__aenter__() @@ -257,24 +338,32 @@ class StreamIO: `timeout` :type write_timeout: :py:class:`int`, :py:class:`float` or :py:class:`None` """ - def __init__(self, reader, writer, *, timeout=None, read_timeout=None, - write_timeout=None): - self.reader = reader - self.writer = writer - self.read_timeout = read_timeout or timeout - self.write_timeout = write_timeout or timeout + + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + *, + timeout: float | int | None = None, + read_timeout: float | int | None = None, + write_timeout: float | int | None = None, + ): + self.reader: asyncio.StreamReader = reader + self.writer: asyncio.StreamWriter = writer + self.read_timeout: float | int | None = read_timeout or timeout + self.write_timeout: float | int | None = write_timeout or timeout @with_timeout("read_timeout") - async def readline(self): + async def readline(self) -> bytes: """ - :py:func:`asyncio.coroutine` + : py: func: `asyncio.coroutine` - Proxy for :py:meth:`asyncio.StreamReader.readline`. + Proxy for: py: meth: `asyncio.StreamReader.readline`. """ return await self.reader.readline() @with_timeout("read_timeout") - async def read(self, count=-1): + async def read(self, count: int = -1) -> bytes: """ :py:func:`asyncio.coroutine` @@ -330,7 +419,7 @@ class Throttle: :type reset_rate: :py:class:`int` or :py:class:`float` """ - def __init__(self, *, limit=None, reset_rate=10): + def __init__(self, *, limit: int | None = None, reset_rate=10): self._limit = limit self.reset_rate = reset_rate self._start = None @@ -338,17 +427,16 @@ def __init__(self, *, limit=None, reset_rate=10): async def wait(self): """ - :py:func:`asyncio.coroutine` + : py: func: `asyncio.coroutine` Wait until can do IO """ - if self._limit is not None and self._limit > 0 and \ - self._start is not None: + if self._limit is not None and self._limit > 0 and self._start is not None: now = _now() end = self._start + self._sum / self._limit await asyncio.sleep(max(0, end - now)) - def append(self, data, start): + def append(self, data: bytes, start: float) -> None: """ Count `data` for throttle @@ -368,14 +456,14 @@ def append(self, data, start): self._sum += len(data) @property - def limit(self): + def limit(self) -> int | None: """ Throttle limit """ return self._limit @limit.setter - def limit(self, value): + def limit(self, value: int | None) -> None: """ Set throttle limit @@ -386,18 +474,17 @@ def limit(self, value): self._start = None self._sum = 0 - def clone(self): + def clone(self) -> Throttle: """ Clone throttle without memory """ return Throttle(limit=self._limit, reset_rate=self.reset_rate) - def __repr__(self): - return f"{self.__class__.__name__}(limit={self._limit!r}, " \ - f"reset_rate={self.reset_rate!r})" + def __repr__(self) -> str: + return f"{self.__class__.__name__}(limit={self._limit!r}, " f"reset_rate={self.reset_rate!r})" -class StreamThrottle(collections.namedtuple("StreamThrottle", "read write")): +class StreamThrottle: # collections.namedtuple("StreamThrottle", "read write") """ Stream throttle with `read` and `write` :py:class:`aioftp.Throttle` @@ -407,17 +494,23 @@ class StreamThrottle(collections.namedtuple("StreamThrottle", "read write")): :param write: stream write throttle :type write: :py:class:`aioftp.Throttle` """ - def clone(self): + + def __init__(self, read: Throttle, write: Throttle): + self.read: Throttle = read + self.write: Throttle = write + + def clone(self) -> StreamThrottle: """ Clone throttles without memory """ - return StreamThrottle( - read=self.read.clone(), - write=self.write.clone() - ) + return StreamThrottle(read=self.read.clone(), write=self.write.clone()) @classmethod - def from_limits(cls, read_speed_limit=None, write_speed_limit=None): + def from_limits( + cls, + read_speed_limit: int | None = None, + write_speed_limit: int | None = None, + ) -> Self: """ Simple wrapper for creation :py:class:`aioftp.StreamThrottle` @@ -429,8 +522,10 @@ def from_limits(cls, read_speed_limit=None, write_speed_limit=None): :py:class:`None` for unlimited :type write_speed_limit: :py:class:`int` or :py:class:`None` """ - return cls(read=Throttle(limit=read_speed_limit), - write=Throttle(limit=write_speed_limit)) + return cls( + read=Throttle(limit=read_speed_limit), + write=Throttle(limit=write_speed_limit), + ) class ThrottleStreamIO(StreamIO): @@ -464,14 +559,14 @@ def __init__(self, *args, throttles={}, **kwargs): super().__init__(*args, **kwargs) self.throttles = throttles - async def wait(self, name): + async def wait(self, name: str): """ - :py:func:`asyncio.coroutine` + : py: func: `asyncio.coroutine` Wait for all throttles - :param name: name of throttle to acquire ("read" or "write") - :type name: :py:class:`str` + : param name: name of throttle to acquire('read' or 'write') + : type name: : py: class: `str` """ tasks = [] for throttle in self.throttles.values(): @@ -481,7 +576,7 @@ async def wait(self, name): if tasks: await asyncio.wait(tasks) - def append(self, name, data, start): + def append(self, name: str, data: bytes, start: float) -> None: """ Update timeout for all throttles @@ -498,11 +593,11 @@ def append(self, name, data, start): for throttle in self.throttles.values(): getattr(throttle, name).append(data, start) - async def read(self, count=-1): + async def read(self, count: int = -1) -> bytes: """ - :py:func:`asyncio.coroutine` + : py: func: `asyncio.coroutine` - :py:meth:`aioftp.StreamIO.read` proxy + : py: meth: `aioftp.StreamIO.read` proxy """ await self.wait("read") start = _now() @@ -510,19 +605,19 @@ async def read(self, count=-1): self.append("read", data, start) return data - async def readline(self): + async def readline(self) -> bytes: """ - :py:func:`asyncio.coroutine` + : py: func: `asyncio.coroutine` - :py:meth:`aioftp.StreamIO.readline` proxy + : py: meth: `aioftp.StreamIO.readline` proxy """ await self.wait("read") start = _now() - data = await super().readline() + data: bytes = await super().readline() self.append("read", data, start) return data - async def write(self, data): + async def write(self, data: bytes) -> None: """ :py:func:`asyncio.coroutine` @@ -536,10 +631,15 @@ async def write(self, data): async def __aenter__(self): return self - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self.close() - def iter_by_line(self): + def iter_by_line(self) -> AsyncStreamIterator: """ Read/iterate stream by line. @@ -552,7 +652,7 @@ def iter_by_line(self): """ return AsyncStreamIterator(self.readline) - def iter_by_block(self, count=DEFAULT_BLOCK_SIZE): + def iter_by_block(self, count: int = DEFAULT_BLOCK_SIZE) -> AsyncStreamIterator: """ Read/iterate stream by block. @@ -570,7 +670,7 @@ def iter_by_block(self, count=DEFAULT_BLOCK_SIZE): @contextmanager -def setlocale(name): +def setlocale(name: str) -> Iterator[str]: """ Context manager with threading lock for set locale on enter, and set it back to original state on exit. diff --git a/aioftp/errors.py b/aioftp/errors.py index 2955cc4..fb37721 100644 --- a/aioftp/errors.py +++ b/aioftp/errors.py @@ -1,12 +1,19 @@ -from . import common +from __future__ import annotations +from types import TracebackType +from typing import TYPE_CHECKING, Any + +from aioftp import common + +if TYPE_CHECKING: + from aioftp.client import Code __all__ = ( "AIOFTPException", - "StatusCodeError", - "PathIsNotAbsolute", - "PathIOError", "NoAvailablePort", + "PathIOError", + "PathIsNotAbsolute", + "StatusCodeError", ) @@ -16,6 +23,9 @@ class AIOFTPException(Exception): """ +EXE_INFO_TYPE = tuple[type[BaseException] | None, BaseException | None, TracebackType | None] + + class StatusCodeError(AIOFTPException): """ Raised for unexpected or "bad" status codes. @@ -41,9 +51,9 @@ class StatusCodeError(AIOFTPException): Exception members are tuples, even for one code. """ - def __init__(self, expected_codes, received_codes, info): - super().__init__(f"Waiting for {expected_codes} but got " - f"{received_codes} {info!r}") + + def __init__(self, expected_codes: tuple[str, ...] | str, received_codes: tuple[Code, ...] | Code, info: list[str]): + super().__init__(f"Waiting for {expected_codes} but got " f"{received_codes} {info!r}") self.expected_codes = common.wrap_with_container(expected_codes) self.received_codes = common.wrap_with_container(received_codes) self.info = info @@ -70,9 +80,10 @@ class PathIOError(AIOFTPException): ... elif ... ... # handle """ - def __init__(self, *args, reason=None, **kwargs): + + def __init__(self, *args: list[Any], reason: EXE_INFO_TYPE | None = None, **kwargs: dict[Any, Any]): super().__init__(*args, **kwargs) - self.reason = reason + self.reason: EXE_INFO_TYPE | None = reason class NoAvailablePort(AIOFTPException, OSError): diff --git a/aioftp/pathio.py b/aioftp/pathio.py index b9252bc..41ec27c 100644 --- a/aioftp/pathio.py +++ b/aioftp/pathio.py @@ -4,24 +4,28 @@ import functools import io import operator +import os import pathlib import stat import sys import time +from io import BufferedRandom +from typing import Any, BinaryIO, List, Self -from . import errors -from .common import ( +from aioftp import errors +from aioftp.common import ( DEFAULT_BLOCK_SIZE, AbstractAsyncLister, AsyncStreamIterator, + Connection, with_timeout, ) __all__ = ( "AbstractPathIO", - "PathIO", "AsyncPathIO", "MemoryPathIO", + "PathIO", "PathIONursery", ) @@ -46,13 +50,14 @@ class AsyncPathIOContext: ... await file.close() """ - def __init__(self, pathio, args, kwargs): + + def __init__(self, pathio: "AbstractPathIO", args: list[Any], kwargs: dict[Any, Any]): self.close = None self.pathio = pathio self.args = args self.kwargs = kwargs - async def __aenter__(self): + async def __aenter__(self) -> Self: self.file = await self.pathio._open(*self.args, **self.kwargs) self.seek = functools.partial(self.pathio.seek, self.file) self.write = functools.partial(self.pathio.write, self.file) @@ -77,12 +82,16 @@ def universal_exception(coro): `NotImplementedError`) with universal exception :py:class:`aioftp.PathIOError` """ + @functools.wraps(coro) async def wrapper(*args, **kwargs): try: return await coro(*args, **kwargs) - except (asyncio.CancelledError, NotImplementedError, - StopAsyncIteration): + except ( + asyncio.CancelledError, + NotImplementedError, + StopAsyncIteration, + ): raise except Exception as exc: raise errors.PathIOError(reason=sys.exc_info()) from exc @@ -91,12 +100,11 @@ async def wrapper(*args, **kwargs): class PathIONursery: - def __init__(self, factory): self.factory = factory self.state = None - def __call__(self, *args, **kwargs): + def __call__(self, *args: list[Any], **kwargs: dict[Any, Any]): instance = self.factory(*args, state=self.state, **kwargs) if self.state is None: self.state = instance.state @@ -108,12 +116,13 @@ def defend_file_methods(coro): Decorator. Raises exception when file methods called with wrapped by :py:class:`aioftp.AsyncPathIOContext` file object. """ + @functools.wraps(coro) - async def wrapper(self, file, *args, **kwargs): + async def wrapper(self, file, *args: list[Any], **kwargs: dict[Any, Any]): if isinstance(file, AsyncPathIOContext): - raise ValueError("Native path io file methods can not be used " - "with wrapped file object") + raise ValueError("Native path io file methods can not be used " "with wrapped file object") return await coro(self, file, *args, **kwargs) + return wrapper @@ -129,9 +138,15 @@ class AbstractPathIO(abc.ABC): :param state: shared pathio state per server """ - def __init__(self, timeout=None, connection=None, state=None): - self.timeout = timeout - self.connection = connection + + def __init__( + self: Self, + timeout: float | int | None = None, + connection: Connection | None = None, + state=None, + ): + self.timeout: float | int | None = timeout + self.connection: Connection | None = connection @property def state(self): @@ -141,7 +156,7 @@ def state(self): @universal_exception @abc.abstractmethod - async def exists(self, path): + async def exists(self, path: pathlib.Path) -> bool: """ :py:func:`asyncio.coroutine` @@ -155,7 +170,7 @@ async def exists(self, path): @universal_exception @abc.abstractmethod - async def is_dir(self, path): + async def is_dir(self, path: pathlib.Path): """ :py:func:`asyncio.coroutine` @@ -169,7 +184,7 @@ async def is_dir(self, path): @universal_exception @abc.abstractmethod - async def is_file(self, path): + async def is_file(self, path: pathlib.Path) -> bool: """ :py:func:`asyncio.coroutine` @@ -183,7 +198,7 @@ async def is_file(self, path): @universal_exception @abc.abstractmethod - async def mkdir(self, path, *, parents=False, exist_ok=False): + async def mkdir(self, path, *, parents=False, exist_ok=False) -> None: """ :py:func:`asyncio.coroutine` @@ -201,7 +216,7 @@ async def mkdir(self, path, *, parents=False, exist_ok=False): @universal_exception @abc.abstractmethod - async def rmdir(self, path): + async def rmdir(self, path: pathlib.Path) -> None: """ :py:func:`asyncio.coroutine` @@ -213,7 +228,7 @@ async def rmdir(self, path): @universal_exception @abc.abstractmethod - async def unlink(self, path): + async def unlink(self, path: pathlib.Path): """ :py:func:`asyncio.coroutine` @@ -224,7 +239,7 @@ async def unlink(self, path): """ @abc.abstractmethod - def list(self, path): + def list(self, path: pathlib.Path) -> AbstractAsyncLister: """ Create instance of subclass of :py:class:`aioftp.AbstractAsyncLister`. You should subclass and implement `__anext__` method @@ -252,7 +267,7 @@ def list(self, path): @universal_exception @abc.abstractmethod - async def stat(self, path): + async def stat(self, path: pathlib.Path): """ :py:func:`asyncio.coroutine` @@ -268,7 +283,7 @@ async def stat(self, path): @universal_exception @abc.abstractmethod - async def _open(self, path, mode): + async def _open(self, path: pathlib.Path, mode: str): """ :py:func:`asyncio.coroutine` @@ -287,7 +302,7 @@ async def _open(self, path, mode): :return: file-object """ - def open(self, *args, **kwargs): + def open(self, *args, **kwargs) -> AsyncPathIOContext: """ Create instance of :py:class:`aioftp.pathio.AsyncPathIOContext`, parameters passed to :py:meth:`aioftp.AbstractPathIO._open` @@ -318,7 +333,7 @@ async def seek(self, file, offset, whence=io.SEEK_SET): @universal_exception @defend_file_methods @abc.abstractmethod - async def write(self, file, data): + async def write(self, file: BinaryIO, data: bytes): """ :py:func:`asyncio.coroutine` @@ -333,7 +348,7 @@ async def write(self, file, data): @universal_exception @defend_file_methods @abc.abstractmethod - async def read(self, file, block_size): + async def read(self, file: BinaryIO, block_size: int) -> bytes: """ :py:func:`asyncio.coroutine` @@ -350,7 +365,7 @@ async def read(self, file, block_size): @universal_exception @defend_file_methods @abc.abstractmethod - async def close(self, file): + async def close(self, file: BinaryIO) -> None: """ :py:func:`asyncio.coroutine` @@ -361,7 +376,7 @@ async def close(self, file): @universal_exception @abc.abstractmethod - async def rename(self, source, destination): + async def rename(self, source: pathlib.Path, destination: pathlib.Path) -> None: """ :py:func:`asyncio.coroutine` @@ -381,31 +396,34 @@ class PathIO(AbstractPathIO): """ @universal_exception - async def exists(self, path): + async def exists(self, path: pathlib.Path) -> bool: return path.exists() @universal_exception - async def is_dir(self, path): + async def is_dir(self, path: pathlib.Path) -> bool: return path.is_dir() @universal_exception - async def is_file(self, path): + async def is_file(self, path: pathlib.Path) -> bool: return path.is_file() @universal_exception - async def mkdir(self, path, *, parents=False, exist_ok=False): + async def mkdir(self, path: pathlib.Path, *, parents: bool = False, exist_ok: bool = False) -> None: return path.mkdir(parents=parents, exist_ok=exist_ok) @universal_exception - async def rmdir(self, path): + async def rmdir(self, path: pathlib.Path) -> None: return path.rmdir() @universal_exception - async def unlink(self, path): + async def unlink(self, path: pathlib.Path) -> None: return path.unlink() - def list(self, path): + @universal_exception + async def size(self, path: pathlib.Path) -> int: + return os.path.getsize(path) + def list(self, path: pathlib.Path) -> AbstractAsyncLister: class Lister(AbstractAsyncLister): iter = None @@ -421,11 +439,11 @@ async def __anext__(self): return Lister(timeout=self.timeout) @universal_exception - async def stat(self, path): + async def stat(self, path: pathlib.Path) -> os.stat_result: return path.stat() @universal_exception - async def _open(self, path, *args, **kwargs): + async def _open(self, path: pathlib.Path, *args: List[Any], **kwargs: dict[Any, Any]) -> BufferedRandom: return path.open(*args, **kwargs) @universal_exception @@ -460,6 +478,7 @@ async def wrapper(self, *args, **kwargs): self.executor, functools.partial(f, self, *args, **kwargs), ) + return wrapper @@ -473,6 +492,7 @@ class AsyncPathIO(AbstractPathIO): :param executor: executor for running blocking tasks :type executor: :py:class:`concurrent.futures.Executor` """ + def __init__(self, *args, executor=None, **kwargs): super().__init__(*args, **kwargs) self.executor = executor @@ -480,19 +500,19 @@ def __init__(self, *args, executor=None, **kwargs): @universal_exception @with_timeout @_blocking_io - def exists(self, path): + def exists(self, path: pathlib.Path): return path.exists() @universal_exception @with_timeout @_blocking_io - def is_dir(self, path): + def is_dir(self, path: pathlib.Path): return path.is_dir() @universal_exception @with_timeout @_blocking_io - def is_file(self, path): + def is_file(self, path: pathlib.Path): return path.is_file() @universal_exception @@ -504,17 +524,16 @@ def mkdir(self, path, *, parents=False, exist_ok=False): @universal_exception @with_timeout @_blocking_io - def rmdir(self, path): + def rmdir(self, path: pathlib.Path): return path.rmdir() @universal_exception @with_timeout @_blocking_io - def unlink(self, path): + def unlink(self, path: pathlib.Path): return path.unlink() - def list(self, path): - + def list(self, path: pathlib.Path): class Lister(AbstractAsyncLister): iter = None @@ -541,9 +560,15 @@ def __anext__(self): @universal_exception @with_timeout @_blocking_io - def stat(self, path): + def stat(self, path: pathlib.Path): return path.stat() + @universal_exception + @with_timeout + @_blocking_io + def size(self, path: pathlib.Path): + return os.path.getsize(path) + @universal_exception @with_timeout @_blocking_io @@ -586,7 +611,6 @@ def rename(self, source, destination): class Node: - def __init__(self, type, name, ctime=None, mtime=None, *, content): self.type = type self.name = name @@ -595,9 +619,11 @@ def __init__(self, type, name, ctime=None, mtime=None, *, content): self.content = content def __repr__(self): - return f"{self.__class__.__name__}(type={self.type!r}, " \ - f"name={self.name!r}, ctime={self.ctime!r}, " \ - f"mtime={self.mtime!r}, content={self.content!r})" + return ( + f"{self.__class__.__name__}(type={self.type!r}, " + f"name={self.name!r}, ctime={self.ctime!r}, " + f"mtime={self.mtime!r}, content={self.content!r})" + ) class MemoryPathIO(AbstractPathIO): @@ -614,12 +640,12 @@ class MemoryPathIO(AbstractPathIO): "st_mtime", "st_nlink", "st_mode", - ) + ), ) def __init__(self, *args, state=None, cwd=None, **kwargs): super().__init__(*args, **kwargs) - self.cwd = pathlib.PurePosixPath(cwd or "/") + self.cwd: pathlib.PurePosixPath = pathlib.PurePosixPath(cwd or "/") if state is None: self.fs = [Node("dir", "/", content=[])] else: @@ -632,12 +658,12 @@ def state(self): def __repr__(self): return repr(self.fs) - def _absolute(self, path): + def _absolute(self, path: pathlib.PurePosixPath) -> pathlib.PurePosixPath: if not path.is_absolute(): path = self.cwd / path return path - def get_node(self, path): + def get_node(self, path: pathlib.PurePosixPath): nodes = self.fs node = None path = self._absolute(path) @@ -653,21 +679,21 @@ def get_node(self, path): return node @universal_exception - async def exists(self, path): + async def exists(self, path: pathlib.PurePosixPath): return self.get_node(path) is not None @universal_exception - async def is_dir(self, path): + async def is_dir(self, path: pathlib.PurePosixPath) -> bool: node = self.get_node(path) return not (node is None or node.type != "dir") @universal_exception - async def is_file(self, path): + async def is_file(self, path: pathlib.PurePosixPath): node = self.get_node(path) return not (node is None or node.type != "file") @universal_exception - async def mkdir(self, path, *, parents=False, exist_ok=False): + async def mkdir(self, path: pathlib.PurePosixPath, *, parents=False, exist_ok=False): path = self._absolute(path) node = self.get_node(path) if node: @@ -697,7 +723,7 @@ async def mkdir(self, path, *, parents=False, exist_ok=False): raise NotADirectoryError @universal_exception - async def rmdir(self, path): + async def rmdir(self, path: pathlib.PurePosixPath): node = self.get_node(path) if node is None: raise FileNotFoundError @@ -713,7 +739,7 @@ async def rmdir(self, path): parent.content.pop(i) @universal_exception - async def unlink(self, path): + async def unlink(self, path: pathlib.PurePosixPath): node = self.get_node(path) if node is None: raise FileNotFoundError @@ -726,8 +752,7 @@ async def unlink(self, path): break parent.content.pop(i) - def list(self, path): - + def list(self, path: pathlib.Path): class Lister(AbstractAsyncLister): iter = None @@ -749,7 +774,7 @@ async def __anext__(cls): return Lister(timeout=self.timeout) @universal_exception - async def stat(self, path): + async def stat(self, path: pathlib.PurePosixPath): node = self.get_node(path) if node is None: raise FileNotFoundError @@ -769,7 +794,19 @@ async def stat(self, path): ) @universal_exception - async def _open(self, path, mode="rb", *args, **kwargs): + async def size(self, path: pathlib.PurePosixPath): + node = self.get_node(path) + if node is None: + raise FileNotFoundError + + if node.type == "file": + size = len(node.content.getbuffer()) + else: + size = 0 + return size + + @universal_exception + async def _open(self, path: pathlib.PurePosixPath, mode: str = "rb", *args, **kwargs): if mode == "rb": node = self.get_node(path) if node is None: @@ -822,7 +859,7 @@ async def close(self, file): pass @universal_exception - async def rename(self, source, destination): + async def rename(self, source: pathlib.PurePosixPath, destination: pathlib.PurePosixPath): if source != destination: sparent = self.get_node(source.parent) dparent = self.get_node(destination.parent) diff --git a/aioftp/server.py b/aioftp/server.py index 723faf0..04924a9 100644 --- a/aioftp/server.py +++ b/aioftp/server.py @@ -1,21 +1,28 @@ +from __future__ import annotations + import abc import asyncio -import collections import enum import errno import functools import logging +import os import pathlib import socket import stat -import sys import time +from collections.abc import Awaitable, Callable, Iterator, Sequence +from typing import Any, Self -from . import errors, pathio -from .common import ( +from aioftp import errors, pathio +from aioftp.common import ( DEFAULT_BLOCK_SIZE, + DEFAULT_MAXIMUM_CONNECTIONS, + DEFAULT_MAXIMUM_CONNECTIONS_PER_USER, END_OF_LINE, HALF_OF_YEAR_IN_SECONDS, + Connection, + StreamIO, StreamThrottle, ThrottleStreamIO, setlocale, @@ -23,25 +30,19 @@ ) __all__ = ( - "Permission", - "User", "AbstractUserManager", - "MemoryUserManager", - "Connection", "AvailableConnections", + "Connection", "ConnectionConditions", + "MemoryUserManager", "PathConditions", "PathPermissions", - "worker", + "Permission", "Server", + "User", + "worker", ) - -IS_PY37_PLUS = sys.version_info[:2] >= (3, 7) -if IS_PY37_PLUS: - get_current_task = asyncio.current_task -else: - get_current_task = asyncio.Task.current_task - +get_current_task = asyncio.current_task logger = logging.getLogger(__name__) @@ -58,21 +59,27 @@ class Permission: :param writable: is writable :type writable: :py:class:`bool` """ - def __init__(self, path="/", *, readable=True, writable=True): - self.path = pathlib.PurePosixPath(path) - self.readable = readable - self.writable = writable - def is_parent(self, other): + def __init__( + self, + path: str | pathlib.PurePosixPath = "/", + *, + readable: bool = True, + writable: bool = True, + ): + self.path: pathlib.PurePosixPath = pathlib.PurePosixPath(path) + self.readable: bool = readable + self.writable: bool = writable + + def is_parent(self, other: pathlib.PurePosixPath) -> bool: try: other.relative_to(self.path) return True except ValueError: return False - def __repr__(self): - return f"{self.__class__.__name__}({self.path!r}, " \ - f"readable={self.readable!r}, writable={self.writable!r})" + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.path!r}, " f"readable={self.readable!r}, writable={self.writable!r})" class User: @@ -113,33 +120,36 @@ class User: connection in bytes per second :type write_speed_limit_per_connection: :py:class:`int` or :py:class:`None` """ - def __init__(self, - login=None, - password=None, *, - base_path=pathlib.Path("."), - home_path=pathlib.PurePosixPath("/"), - permissions=None, - maximum_connections=None, - read_speed_limit=None, - write_speed_limit=None, - read_speed_limit_per_connection=None, - write_speed_limit_per_connection=None): - self.login = login - self.password = password - self.base_path = pathlib.Path(base_path) - self.home_path = pathlib.PurePosixPath(home_path) + + def __init__( + self, + login: str | None = None, + password: str | None = None, + *, + base_path: pathlib.Path = pathlib.Path("."), + home_path: pathlib.PurePosixPath = pathlib.PurePosixPath("/"), + permissions: Sequence[Permission] | None = None, + maximum_connections: int = DEFAULT_MAXIMUM_CONNECTIONS_PER_USER, + read_speed_limit: int | None = None, + write_speed_limit: int | None = None, + read_speed_limit_per_connection: int | None = None, + write_speed_limit_per_connection: int | None = None, + ): + self.login: str | None = login + self.password: str | None = password + self.base_path: pathlib.Path = pathlib.Path(base_path) + self.home_path: pathlib.PurePosixPath = pathlib.PurePosixPath(home_path) if not self.home_path.is_absolute(): raise errors.PathIsNotAbsolute(home_path) - self.permissions = permissions or [Permission()] - self.maximum_connections = maximum_connections - self.read_speed_limit = read_speed_limit - self.write_speed_limit = write_speed_limit - self.read_speed_limit_per_connection = read_speed_limit_per_connection + self.permissions: Sequence[Permission] = permissions or [Permission()] + self.maximum_connections: int = maximum_connections + self.read_speed_limit: int | None = read_speed_limit + self.write_speed_limit: int | None = write_speed_limit + self.read_speed_limit_per_connection: int | None = read_speed_limit_per_connection # damn 80 symbols - self.write_speed_limit_per_connection = \ - write_speed_limit_per_connection + self.write_speed_limit_per_connection: int | None = write_speed_limit_per_connection - async def get_permissions(self, path): + async def get_permissions(self, path: str | pathlib.PurePosixPath) -> Permission: """ Return nearest parent permission for `path`. @@ -148,27 +158,38 @@ async def get_permissions(self, path): :rtype: :py:class:`aioftp.Permission` """ - path = pathlib.PurePosixPath(path) - parents = filter(lambda p: p.is_parent(path), self.permissions) + path_ = pathlib.PurePosixPath(path) + parents = filter(lambda p: p.is_parent(path_), self.permissions) perm = min( parents, - key=lambda p: len(path.relative_to(p.path).parts), + key=lambda p: len(path_.relative_to(p.path).parts), default=Permission(), ) return perm - def __repr__(self): - return f"{self.__class__.__name__}({self.login!r}, " \ - f"{self.password!r}, base_path={self.base_path!r}, " \ - f"home_path={self.home_path!r}, " \ - f"permissions={self.permissions!r}, " \ - f"maximum_connections={self.maximum_connections!r}, " \ - f"read_speed_limit={self.read_speed_limit!r}, " \ - f"write_speed_limit={self.write_speed_limit!r}, " \ - f"read_speed_limit_per_connection=" \ - f"{self.read_speed_limit_per_connection!r}, " \ - f"write_speed_limit_per_connection=" \ - f"{self.write_speed_limit_per_connection!r})" + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.login!r}, " + f"{self.password!r}, base_path={self.base_path!r}, " + f"home_path={self.home_path!r}, " + f"permissions={self.permissions!r}, " + f"maximum_connections={self.maximum_connections!r}, " + f"read_speed_limit={self.read_speed_limit!r}, " + f"write_speed_limit={self.write_speed_limit!r}, " + f"read_speed_limit_per_connection=" + f"{self.read_speed_limit_per_connection!r}, " + f"write_speed_limit_per_connection=" + f"{self.write_speed_limit_per_connection!r})" + ) + + enum.Enum("UserManagerResponse", "OK PASSWORD_REQUIRED ERROR") + + +class GetUserResponse(enum.Enum): + # "UserManagerResponse", + OK = 1 + PASSWORD_REQUIRED = 2 + ERROR = 3 class AbstractUserManager(abc.ABC): @@ -179,16 +200,11 @@ class AbstractUserManager(abc.ABC): :type timeout: :py:class:`float`, :py:class:`int` or :py:class:`None` """ - GetUserResponse = enum.Enum( - "UserManagerResponse", - "OK PASSWORD_REQUIRED ERROR" - ) - - def __init__(self, *, timeout=None): + def __init__(self, *, timeout: float | int | None = None): self.timeout = timeout @abc.abstractmethod - async def get_user(self, login): + async def get_user(self, login: str) -> tuple[GetUserResponse, User | None, str]: """ :py:func:`asyncio.coroutine` @@ -199,7 +215,7 @@ async def get_user(self, login): """ @abc.abstractmethod - async def authenticate(self, user, password): + async def authenticate(self, user: User, password: str) -> bool: """ :py:func:`asyncio.coroutine` @@ -214,7 +230,7 @@ async def authenticate(self, user, password): :rtype: :py:class:`bool` """ - async def notify_logout(self, user): + async def notify_logout(self, user: User) -> None: """ :py:func:`asyncio.coroutine` @@ -233,16 +249,16 @@ class MemoryUserManager(AbstractUserManager): :type users: :py:class:`list`, :py:class:`tuple`, etc. of :py:class:`aioftp.User` """ - def __init__(self, users, *args, **kwargs): + + def __init__(self, users: Sequence[User], *args: list[Any], **kwargs: dict[Any, Any]): super().__init__(*args, **kwargs) - self.users = users or [User()] - self.available_connections = dict( - (user, AvailableConnections(user.maximum_connections)) - for user in self.users + self.users: Sequence[User] = users or [User()] + self.available_connections: dict[User | None, AvailableConnections] = dict( + (user, AvailableConnections(user.maximum_connections)) for user in self.users ) - async def get_user(self, login): - user = None + async def get_user(self, login: str | None) -> tuple[GetUserResponse, User | None, str]: + user: User | None = None for u in self.users: if u.login is None and user is None: user = u @@ -250,100 +266,32 @@ async def get_user(self, login): user = u break if user is None: - state = AbstractUserManager.GetUserResponse.ERROR + state = GetUserResponse.ERROR info = "no such username" elif self.available_connections[user].locked(): - state = AbstractUserManager.GetUserResponse.ERROR + state = GetUserResponse.ERROR info = f"too much connections for {user.login or 'anonymous'!r}" elif user.login is None: - state = AbstractUserManager.GetUserResponse.OK + state = GetUserResponse.OK info = "anonymous login" elif user.password is None: - state = AbstractUserManager.GetUserResponse.OK + state = GetUserResponse.OK info = "login without password" else: - state = AbstractUserManager.GetUserResponse.PASSWORD_REQUIRED + state = GetUserResponse.PASSWORD_REQUIRED info = "password required" - if state != AbstractUserManager.GetUserResponse.ERROR: + if state != GetUserResponse.ERROR: self.available_connections[user].acquire() return state, user, info - async def authenticate(self, user, password): + async def authenticate(self, user: User, password: str) -> bool: return user.password == password - async def notify_logout(self, user): + async def notify_logout(self, user: User) -> None: self.available_connections[user].release() -class Connection(collections.defaultdict): - """ - Connection state container for transparent work with futures for async - wait - - :param kwargs: initialization parameters - - Container based on :py:class:`collections.defaultdict`, which holds - :py:class:`asyncio.Future` as default factory. There is two layers of - abstraction: - - * Low level based on simple dictionary keys to attributes mapping and - available at Connection.future. - * High level based on futures result and dictionary keys to attributes - mapping and available at Connection. - - To clarify, here is groups of equal expressions - :: - - >>> connection.future.foo - >>> connection["foo"] - - >>> connection.foo - >>> connection["foo"].result() - - >>> del connection.future.foo - >>> del connection.foo - >>> del connection["foo"] - """ - - __slots__ = ("future",) - - class Container: - - def __init__(self, storage): - self.storage = storage - - def __getattr__(self, name): - return self.storage[name] - - def __delattr__(self, name): - self.storage.pop(name) - - def __init__(self, **kwargs): - super().__init__(asyncio.Future) - self.future = Connection.Container(self) - for k, v in kwargs.items(): - self[k].set_result(v) - - def __getattr__(self, name): - if name in self: - return self[name].result() - else: - raise AttributeError(f"{name!r} not in storage") - - def __setattr__(self, name, value): - if name in Connection.__slots__: - super().__setattr__(name, value) - else: - if self[name].done(): - self[name] = super().default_factory() - self[name].set_result(value) - - def __delattr__(self, name): - if name in self: - self.pop(name) - - class AvailableConnections: """ Semaphore-like object. Have no blocks, only raises ValueError on bounds @@ -353,10 +301,10 @@ class AvailableConnections: :type value: :py:class:`int` or :py:class:`None` """ - def __init__(self, value=None): + def __init__(self, value: int | None = None): self.value = self.maximum_value = value - def locked(self): + def locked(self) -> bool: """ Returns True if semaphore-like can not be acquired. @@ -364,7 +312,7 @@ def locked(self): """ return self.value == 0 - def acquire(self): + def acquire(self) -> None: """ Acquire, decrementing the internal counter by one. """ @@ -373,10 +321,11 @@ def acquire(self): if self.value < 0: raise ValueError("Too many acquires") - def release(self): + def release(self) -> None: """ Release, incrementing the internal counter by one. """ + assert self.maximum_value is not None if self.value is not None: self.value += 1 if self.value > self.maximum_value: @@ -421,25 +370,31 @@ class ConnectionConditions: ... def foo(self, connection, rest): ... ... """ + user_required = ("user", "no user (use USER firstly)") login_required = ("logged", "not logged in") passive_server_started = ( "passive_server", - "no listen socket created (use PASV firstly)" + "no listen socket created (use PASV firstly)", ) data_connection_made = ("data_connection", "no data connection made") rename_from_required = ("rename_from", "no filename (use RNFR firstly)") - def __init__(self, *fields, wait=False, fail_code="503", fail_info=None): + def __init__( + self, + *fields: list[Any], + wait: bool = False, + fail_code: str = "503", + fail_info: str | None = None, + ): self.fields = fields - self.wait = wait - self.fail_code = fail_code - self.fail_info = fail_info - - def __call__(self, f): + self.wait: bool = wait + self.fail_code: str = fail_code + self.fail_info: str | None = fail_info + def __call__(self, f: Callable[[Server, Connection, str, list[Any]], Awaitable[bool]]) -> Any: @functools.wraps(f) - async def wrapper(cls, connection, rest, *args): + async def wrapper(cls: Server, connection: Connection, rest: str, *args: list[Any]) -> bool: futures = {connection[name]: msg for name, msg in self.fields} aggregate = asyncio.gather(*futures) if self.wait: @@ -483,16 +438,16 @@ class PathConditions: ... def foo(self, connection, path): ... ... """ + path_must_exists = ("exists", False, "path does not exists") path_must_not_exists = ("exists", True, "path already exists") path_must_be_dir = ("is_dir", False, "path is not a directory") path_must_be_file = ("is_file", False, "path is not a file") - def __init__(self, *conditions): + def __init__(self, *conditions: list[tuple[str, bool, str]]): self.conditions = conditions def __call__(self, f): - @functools.wraps(f) async def wrapper(cls, connection, rest, *args): real_path, virtual_path = cls.get_paths(connection, rest) @@ -525,18 +480,18 @@ class PathPermissions: ... def foo(self, connection, path): ... ... """ + readable = "readable" writable = "writable" - def __init__(self, *permissions): + def __init__(self, *permissions: list[str]): self.permissions = permissions - def __call__(self, f): - + def __call__(self, f: Callable[[Server, Connection, str, list[Any]], Awaitable[bool]]) -> Any: @functools.wraps(f) - async def wrapper(cls, connection, rest, *args): + async def wrapper(cls: Server, connection: Connection, rest: str, *args: list[Any]) -> bool: real_path, virtual_path = cls.get_paths(connection, rest) - current_permission = await connection.user.get_permissions( + current_permission: Permission = await connection.user.get_permissions( virtual_path, ) for permission in self.permissions: @@ -544,6 +499,7 @@ async def wrapper(cls, connection, rest, *args): connection.response("550", "permission denied") return True return await f(cls, connection, rest, *args) + return False return wrapper @@ -560,8 +516,9 @@ def worker(f): ... ... """ + @functools.wraps(f) - async def wrapper(cls, connection, rest): + async def wrapper(cls, connection: Connection, rest: str): try: await f(cls, connection, rest) except asyncio.CancelledError: @@ -637,34 +594,36 @@ class Server: Please look :py:meth:`asyncio.loop.create_server` docs. :type ssl: :py:class:`ssl.SSLContext` """ - def __init__(self, - users=None, - *, - block_size=DEFAULT_BLOCK_SIZE, - socket_timeout=None, - idle_timeout=None, - wait_future_timeout=1, - path_timeout=None, - path_io_factory=pathio.PathIO, - maximum_connections=None, - read_speed_limit=None, - write_speed_limit=None, - read_speed_limit_per_connection=None, - write_speed_limit_per_connection=None, - ipv4_pasv_forced_response_address=None, - data_ports=None, - encoding="utf-8", - ssl=None): + + def __init__( + self, + users: Sequence[User] | AbstractUserManager | None = None, + *, + block_size: int = DEFAULT_BLOCK_SIZE, + socket_timeout: float | int | None = None, + idle_timeout: float | int | None = None, + wait_future_timeout: float | int | None = 1, + path_timeout: float | int | None = None, + path_io_factory: type[pathio.AbstractPathIO] = pathio.PathIO, + maximum_connections: int = DEFAULT_MAXIMUM_CONNECTIONS, + read_speed_limit: int | None = None, + write_speed_limit: int | None = None, + read_speed_limit_per_connection: int | None = None, + write_speed_limit_per_connection: int | None = None, + ipv4_pasv_forced_response_address: str | None = None, + data_ports: Iterator[int] | None = None, + encoding: str = "utf-8", + ssl=None, + ): self.block_size = block_size self.socket_timeout = socket_timeout self.idle_timeout = idle_timeout self.wait_future_timeout = wait_future_timeout self.path_io_factory = pathio.PathIONursery(path_io_factory) self.path_timeout = path_timeout - self.ipv4_pasv_forced_response_address = \ - ipv4_pasv_forced_response_address + self.ipv4_pasv_forced_response_address = ipv4_pasv_forced_response_address if data_ports is not None: - self.available_data_ports = asyncio.PriorityQueue() + self.available_data_ports: asyncio.PriorityQueue | None = asyncio.PriorityQueue() for data_port in data_ports: self.available_data_ports.put_nowait((0, data_port)) else: @@ -684,7 +643,7 @@ def __init__(self, read_speed_limit_per_connection, write_speed_limit_per_connection, ) - self.throttle_per_user = {} + self.throttle_per_user: dict[User, StreamThrottle] = {} self.encoding = encoding self.ssl = ssl self.commands_mapping = { @@ -713,9 +672,10 @@ def __init__(self, "syst": self.syst, "type": self.type, "user": self.user, + "size": self.size, } - async def start(self, host=None, port=0, **kwargs): + async def start(self, host: str | None = None, port: int = 0, **kwargs: dict[Any, Any]): """ :py:func:`asyncio.coroutine` @@ -731,15 +691,16 @@ async def start(self, host=None, port=0, **kwargs): :py:func:`asyncio.start_server` """ self._start_server_extra_arguments = kwargs - self.connections = {} + self.connections: dict[ThrottleStreamIO, Connection] = {} self.server_host = host self.server_port = port - self.server = await asyncio.start_server( + self.server: asyncio.base_events.Server = await asyncio.start_server( self.dispatcher, host, port, ssl=self.ssl, **self._start_server_extra_arguments, + # kwds=self._start_server_extra_arguments ) for sock in self.server.sockets: if sock.family in (socket.AF_INET, socket.AF_INET6): @@ -750,7 +711,7 @@ async def start(self, host=None, port=0, **kwargs): self.server_host = host logger.info("serving on %s:%s", host, port) - async def serve_forever(self): + async def serve_forever(self) -> None: """ :py:func:`asyncio.coroutine` @@ -780,13 +741,13 @@ async def run(self, host=None, port=0, **kwargs): await self.close() @property - def address(self): + def address(self) -> tuple[str | None, int]: """ Server listen socket host and port as :py:class:`tuple` """ return self.server_host, self.server_port - async def close(self): + async def close(self) -> None: """ :py:func:`asyncio.coroutine` @@ -800,11 +761,17 @@ async def close(self): logger.debug("waiting for %d tasks", len(tasks)) await asyncio.wait(tasks) - async def write_line(self, stream, line): + async def write_line(self, stream: StreamIO, line: str) -> None: logger.debug(line) await stream.write((line + END_OF_LINE).encode(encoding=self.encoding)) - async def write_response(self, stream, code, lines="", list=False): + async def write_response( + self: Self, + stream: StreamIO, + code: str, + line_or_lines: str | Iterator[str] = "", + list: bool = False, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -816,14 +783,14 @@ async def write_response(self, stream, code, lines="", list=False): :param code: server response code :type code: :py:class:`str` - :param lines: line or lines, which are response information + :param line_or_lines: line or lines, which are response information :type lines: :py:class:`str` or :py:class:`collections.Iterable` :param list: if true, then lines will be sended without code prefix. This is useful for **LIST** FTP command and some others. :type list: :py:class:`bool` """ - lines = wrap_with_container(lines) + lines = wrap_with_container(line_or_lines) write = functools.partial(self.write_line, stream) if list: head, *body, tail = lines @@ -837,7 +804,7 @@ async def write_response(self, stream, code, lines="", list=False): await write(code + "-" + line) await write(code + " " + tail) - async def parse_command(self, stream, censor_commands=("pass",)): + async def parse_command(self, stream: StreamIO, censor_commands: tuple[str, ...] = ("pass",)) -> tuple[str, str]: """ :py:func:`asyncio.coroutine` @@ -866,7 +833,7 @@ async def parse_command(self, stream, censor_commands=("pass",)): return cmd.lower(), rest - async def response_writer(self, stream, response_queue): + async def response_writer(self, stream: StreamIO, response_queue: asyncio.Queue) -> None: """ :py:func:`asyncio.coroutine` @@ -887,7 +854,11 @@ async def response_writer(self, stream, response_queue): finally: response_queue.task_done() - async def dispatcher(self, reader, writer): + async def dispatcher( + self, + reader: asyncio.streams.StreamReader, + writer: asyncio.streams.StreamWriter, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -901,12 +872,12 @@ async def dispatcher(self, reader, writer): writer, throttles=dict( server_global=self.throttle, - server_per_connection=self.throttle_per_connection.clone() + server_per_connection=self.throttle_per_connection.clone(), ), read_timeout=self.idle_timeout, write_timeout=self.socket_timeout, ) - response_queue = asyncio.Queue() + response_queue: asyncio.Queue = asyncio.Queue() connection = Connection( client_host=host, client_port=port, @@ -926,8 +897,7 @@ async def dispatcher(self, reader, writer): restart_offset=0, _dispatcher=get_current_task(), ) - connection.path_io = self.path_io_factory(timeout=self.path_timeout, - connection=connection) + connection.path_io = self.path_io_factory(timeout=self.path_timeout, connection=connection) pending = { asyncio.create_task(self.greeting(connection, "")), asyncio.create_task(self.response_writer(stream, response_queue)), @@ -954,15 +924,11 @@ async def dispatcher(self, reader, writer): return # this is parse_command result elif isinstance(result, tuple): - pending.add( - asyncio.create_task(self.parse_command(stream)) - ) + pending.add(asyncio.create_task(self.parse_command(stream))) cmd, rest = result f = self.commands_mapping.get(cmd) if f is not None: - pending.add( - asyncio.create_task(f(connection, rest)) - ) + pending.add(asyncio.create_task(f(connection, rest))) if cmd not in ("retr", "stor", "appe"): connection.restart_offset = 0 else: @@ -990,16 +956,17 @@ async def dispatcher(self, reader, writer): if connection.acquired: self.available_connections.release() if connection.future.user.done(): - task = asyncio.create_task( - self.user_manager.notify_logout(connection.user) - ) + task = asyncio.create_task(self.user_manager.notify_logout(connection.user)) tasks_to_wait.append(task) self.connections.pop(key) if tasks_to_wait: await asyncio.wait(tasks_to_wait) @staticmethod - def get_paths(connection, path): + def get_paths( + connection: Connection, + path: str | pathlib.PurePosixPath, + ) -> tuple[pathlib.Path, pathlib.PurePosixPath]: """ Return *real* and *virtual* paths, resolves ".." with "up" action. *Real* path is path for path_io, when *virtual* deals with @@ -1014,7 +981,7 @@ def get_paths(connection, path): :return: (real_path, virtual_path) :rtype: (:py:class:`pathlib.Path`, :py:class:`pathlib.PurePosixPath`) """ - virtual_path = pathlib.PurePosixPath(path) + virtual_path: pathlib.PurePosixPath = pathlib.PurePosixPath(path) if not virtual_path.is_absolute(): virtual_path = connection.current_directory / virtual_path resolved_virtual_path = pathlib.PurePosixPath("/") @@ -1033,7 +1000,7 @@ def get_paths(connection, path): resolved_virtual_path = pathlib.PurePosixPath("/") return real_path, resolved_virtual_path - async def greeting(self, connection, rest): + async def greeting(self, connection: Connection, rest: str): if self.available_connections.locked(): ok, code, info = False, "421", "Too many connections" else: @@ -1043,20 +1010,20 @@ async def greeting(self, connection, rest): connection.response(code, info) return ok - async def user(self, connection, rest): + async def user(self, connection: Connection, rest: str): if connection.future.user.done(): await self.user_manager.notify_logout(connection.user) del connection.user del connection.logged state, user, info = await self.user_manager.get_user(rest) - if state == AbstractUserManager.GetUserResponse.OK: + if state == GetUserResponse.OK: code = "230" connection.logged = True connection.user = user - elif state == AbstractUserManager.GetUserResponse.PASSWORD_REQUIRED: + elif state == GetUserResponse.PASSWORD_REQUIRED: code = "331" connection.user = user - elif state == AbstractUserManager.GetUserResponse.ERROR: + elif state == GetUserResponse.ERROR: code = "530" else: message = f"Unknown response {state}" @@ -1076,13 +1043,13 @@ async def user(self, connection, rest): user_per_connection=StreamThrottle.from_limits( connection.user.read_speed_limit_per_connection, connection.user.write_speed_limit_per_connection, - ) + ), ) connection.response(code, info) return True @ConnectionConditions(ConnectionConditions.user_required) - async def pass_(self, connection, rest): + async def pass_(self, connection: Connection, rest: str): if connection.future.logged.done(): code, info = "503", "already logged in" elif await self.user_manager.authenticate(connection.user, rest): @@ -1093,63 +1060,59 @@ async def pass_(self, connection, rest): connection.response(code, info) return True - async def quit(self, connection, rest): + async def quit(self, connection: Connection, rest: str): connection.response("221", "bye") return False @ConnectionConditions(ConnectionConditions.login_required) - async def pwd(self, connection, rest): - code, info = "257", f"\"{connection.current_directory}\"" + async def pwd(self, connection: Connection, rest: str): + code, info = "257", f'"{connection.current_directory}"' connection.response(code, info) return True @ConnectionConditions(ConnectionConditions.login_required) - @PathConditions( - PathConditions.path_must_exists, - PathConditions.path_must_be_dir) + @PathConditions(PathConditions.path_must_exists, PathConditions.path_must_be_dir) @PathPermissions(PathPermissions.readable) - async def cwd(self, connection, rest): + async def cwd(self, connection: Connection, rest: str): real_path, virtual_path = self.get_paths(connection, rest) connection.current_directory = virtual_path connection.response("250", "") return True @ConnectionConditions(ConnectionConditions.login_required) - async def cdup(self, connection, rest): + async def cdup(self, connection: Connection, rest: str): return await self.cwd(connection, connection.current_directory.parent) @ConnectionConditions(ConnectionConditions.login_required) @PathConditions(PathConditions.path_must_not_exists) @PathPermissions(PathPermissions.writable) - async def mkd(self, connection, rest): + async def mkd(self, connection: Connection, rest: str): real_path, virtual_path = self.get_paths(connection, rest) await connection.path_io.mkdir(real_path, parents=True) connection.response("257", "") return True @ConnectionConditions(ConnectionConditions.login_required) - @PathConditions( - PathConditions.path_must_exists, - PathConditions.path_must_be_dir) + @PathConditions(PathConditions.path_must_exists, PathConditions.path_must_be_dir) @PathPermissions(PathPermissions.writable) - async def rmd(self, connection, rest): + async def rmd(self, connection: Connection, rest: str): real_path, virtual_path = self.get_paths(connection, rest) await connection.path_io.rmdir(real_path) connection.response("250", "") return True @staticmethod - def _format_mlsx_time(local_seconds): + def _format_mlsx_time(local_seconds: float) -> str: return time.strftime("%Y%m%d%H%M%S", time.gmtime(local_seconds)) - def _build_mlsx_facts_from_stats(self, stats): + def _build_mlsx_facts_from_stats(self, stats: os.stat_result): return { "Size": stats.st_size, "Create": self._format_mlsx_time(stats.st_ctime), "Modify": self._format_mlsx_time(stats.st_mtime), } - async def build_mlsx_string(self, connection, path): + async def build_mlsx_string(self, connection: Connection, path: pathlib.Path): if not await connection.path_io.exists(path): facts = {} else: @@ -1170,18 +1133,19 @@ async def build_mlsx_string(self, connection, path): @ConnectionConditions( ConnectionConditions.login_required, - ConnectionConditions.passive_server_started) + ConnectionConditions.passive_server_started, + ) @PathConditions(PathConditions.path_must_exists) @PathPermissions(PathPermissions.readable) - async def mlsd(self, connection, rest): - + async def mlsd(self, connection: Connection, rest: str): @ConnectionConditions( ConnectionConditions.data_connection_made, wait=True, fail_code="425", - fail_info="Can't open data connection") + fail_info="Can't open data connection", + ) @worker - async def mlsd_worker(self, connection, rest): + async def mlsd_worker(self, connection: Connection, rest: str): stream = connection.data_connection del connection.data_connection async with stream: @@ -1200,10 +1164,10 @@ async def mlsd_worker(self, connection, rest): return True @staticmethod - def build_list_mtime(st_mtime, now=None): + def build_list_mtime(st_mtime: float, now: float | None = None) -> str: if now is None: now = time.time() - mtime = time.localtime(st_mtime) + mtime: time.struct_time = time.localtime(st_mtime) with setlocale("C"): if now - HALF_OF_YEAR_IN_SECONDS < st_mtime <= now: s = time.strftime("%b %e %H:%M", mtime) @@ -1211,7 +1175,7 @@ def build_list_mtime(st_mtime, now=None): s = time.strftime("%b %e %Y", mtime) return s - async def build_list_string(self, connection, path): + async def build_list_string(self, connection: Connection, path: pathlib.Path): stats = await connection.path_io.stat(path) mtime = self.build_list_mtime(stats.st_mtime) fields = ( @@ -1221,25 +1185,26 @@ async def build_list_string(self, connection, path): "none", str(stats.st_size), mtime, - path.name + path.name, ) s = " ".join(fields) return s @ConnectionConditions( ConnectionConditions.login_required, - ConnectionConditions.passive_server_started) + ConnectionConditions.passive_server_started, + ) @PathConditions(PathConditions.path_must_exists) @PathPermissions(PathPermissions.readable) - async def list(self, connection, rest): - + async def list(self, connection: Connection, rest: str): @ConnectionConditions( ConnectionConditions.data_connection_made, wait=True, fail_code="425", - fail_info="Can't open data connection") + fail_info="Can't open data connection", + ) @worker - async def list_worker(self, connection, rest): + async def list_worker(self, connection: Connection, rest: str): stream = connection.data_connection del connection.data_connection async with stream: @@ -1263,7 +1228,7 @@ async def list_worker(self, connection, rest): @ConnectionConditions(ConnectionConditions.login_required) @PathConditions(PathConditions.path_must_exists) @PathPermissions(PathPermissions.readable) - async def mlst(self, connection, rest): + async def mlst(self, connection: Connection, rest: str): real_path, virtual_path = self.get_paths(connection, rest) s = await self.build_mlsx_string(connection, real_path) connection.response("250", ["start", s, "end"], True) @@ -1272,7 +1237,7 @@ async def mlst(self, connection, rest): @ConnectionConditions(ConnectionConditions.login_required) @PathConditions(PathConditions.path_must_exists) @PathPermissions(PathPermissions.writable) - async def rnfr(self, connection, rest): + async def rnfr(self, connection: Connection, rest: str): real_path, virtual_path = self.get_paths(connection, rest) connection.rename_from = real_path connection.response("350", "rename from accepted") @@ -1280,10 +1245,11 @@ async def rnfr(self, connection, rest): @ConnectionConditions( ConnectionConditions.login_required, - ConnectionConditions.rename_from_required) + ConnectionConditions.rename_from_required, + ) @PathConditions(PathConditions.path_must_not_exists) @PathPermissions(PathPermissions.writable) - async def rnto(self, connection, rest): + async def rnto(self, connection: Connection, rest: str): real_path, virtual_path = self.get_paths(connection, rest) rename_from = connection.rename_from del connection.rename_from @@ -1292,29 +1258,39 @@ async def rnto(self, connection, rest): return True @ConnectionConditions(ConnectionConditions.login_required) - @PathConditions( - PathConditions.path_must_exists, - PathConditions.path_must_be_file) + @PathConditions(PathConditions.path_must_exists, PathConditions.path_must_be_file) @PathPermissions(PathPermissions.writable) - async def dele(self, connection, rest): + async def dele(self, connection: Connection, rest: str): real_path, virtual_path = self.get_paths(connection, rest) await connection.path_io.unlink(real_path) connection.response("250", "") return True + @ConnectionConditions(ConnectionConditions.login_required) + @PathConditions(PathConditions.path_must_exists, PathConditions.path_must_be_file) + async def size(self, connection: Connection, rest): + if connection.transfer_type == "A": + connection.response("550", "SIZE not allowed in ASCII mode") + return True + real_path, virtual_path = self.get_paths(connection, rest) + file_size = await connection.path_io.size(real_path) + connection.response("213", str(file_size)) + return True + @ConnectionConditions( ConnectionConditions.login_required, - ConnectionConditions.passive_server_started) + ConnectionConditions.passive_server_started, + ) @PathPermissions(PathPermissions.writable) - async def stor(self, connection, rest, mode="wb"): - + async def stor(self, connection: Connection, rest: str, mode="wb") -> bool: @ConnectionConditions( ConnectionConditions.data_connection_made, wait=True, fail_code="425", - fail_info="Can't open data connection") + fail_info="Can't open data connection", + ) @worker - async def stor_worker(self, connection, rest): + async def stor_worker(self, connection: Connection, rest: str): stream = connection.data_connection del connection.data_connection if connection.restart_offset: @@ -1343,20 +1319,19 @@ async def stor_worker(self, connection, rest): @ConnectionConditions( ConnectionConditions.login_required, - ConnectionConditions.passive_server_started) - @PathConditions( - PathConditions.path_must_exists, - PathConditions.path_must_be_file) + ConnectionConditions.passive_server_started, + ) + @PathConditions(PathConditions.path_must_exists, PathConditions.path_must_be_file) @PathPermissions(PathPermissions.readable) - async def retr(self, connection, rest): - + async def retr(self, connection: Connection, rest: str): @ConnectionConditions( ConnectionConditions.data_connection_made, wait=True, fail_code="425", - fail_info="Can't open data connection") + fail_info="Can't open data connection", + ) @worker - async def retr_worker(self, connection, rest): + async def retr_worker(self, connection: Connection, rest: str): stream = connection.data_connection del connection.data_connection file_in = connection.path_io.open(real_path, mode="rb") @@ -1376,7 +1351,7 @@ async def retr_worker(self, connection, rest): return True @ConnectionConditions(ConnectionConditions.login_required) - async def type(self, connection, rest): + async def type(self, connection: Connection, rest: str): if rest in ("I", "A"): connection.transfer_type = rest code, info = "200", "" @@ -1386,12 +1361,12 @@ async def type(self, connection, rest): return True @ConnectionConditions(ConnectionConditions.login_required) - async def pbsz(self, connection, rest): + async def pbsz(self, connection: Connection, rest: str): connection.response("200", "") return True @ConnectionConditions(ConnectionConditions.login_required) - async def prot(self, connection, rest): + async def prot(self, connection: Connection, rest): if rest == "P": code, info = "200", "" else: @@ -1399,7 +1374,11 @@ async def prot(self, connection, rest): connection.response(code, info) return True - async def _start_passive_server(self, connection, handler_callback): + async def _start_passive_server( + self, + connection: Connection, + handler_callback: Callable, + ) -> asyncio.base_events.Server: if self.available_data_ports is not None: viewed_ports = set() while True: @@ -1434,8 +1413,7 @@ async def _start_passive_server(self, connection, handler_callback): return passive_server @ConnectionConditions(ConnectionConditions.login_required) - async def pasv(self, connection, rest): - + async def pasv(self, connection: Connection, rest: str): async def handler(reader, writer): if connection.future.data_connection.done(): writer.close() @@ -1471,7 +1449,7 @@ async def handler(reader, writer): connection.response("503", ["this server started in ipv6 mode"]) return False - nums = tuple(map(int, host.split("."))) + (port >> 8, port & 0xff) + nums = tuple(map(int, host.split("."))) + (port >> 8, port & 0xFF) info.append(f"({','.join(map(str, nums))})") if connection.future.data_connection.done(): connection.data_connection.close() @@ -1480,8 +1458,7 @@ async def handler(reader, writer): return True @ConnectionConditions(ConnectionConditions.login_required) - async def epsv(self, connection, rest): - + async def epsv(self, connection: Connection, rest: str) -> bool: async def handler(reader, writer): if connection.future.data_connection.done(): writer.close() @@ -1521,7 +1498,7 @@ async def handler(reader, writer): return True @ConnectionConditions(ConnectionConditions.login_required) - async def abor(self, connection, rest): + async def abor(self, connection: Connection, rest: str) -> bool: if connection.extra_workers: for worker in connection.extra_workers: worker.cancel() @@ -1529,10 +1506,10 @@ async def abor(self, connection, rest): connection.response("226", "nothing to abort") return True - async def appe(self, connection, rest): + async def appe(self, connection: Connection, rest: str) -> bool: return await self.stor(connection, rest, "ab") - async def rest(self, connection, rest): + async def rest(self, connection: Connection, rest: str) -> bool: if rest.isdigit(): connection.restart_offset = int(rest) connection.response("350", f"restarting at {rest}") @@ -1542,7 +1519,7 @@ async def rest(self, connection, rest): connection.response("501", message) return True - async def syst(self, connection, rest): + async def syst(self, connection: Connection, rest: str) -> bool: """Return system type (always returns UNIX type: L8).""" connection.response("215", "UNIX Type: L8") return True diff --git a/docs/conf.py b/docs/conf.py index 1408b6a..a432f14 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,42 +22,42 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) import aioftp # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'alabaster', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "alabaster", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'aioftp' -copyright = '2016, pohmelie' -author = 'pohmelie' +project = "aioftp" +copyright = "2016, pohmelie" +author = "pohmelie" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -77,37 +77,37 @@ # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -117,26 +117,26 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - 'logo': 'logo.png', - 'description': 'ftp client/server for asyncio', - 'github_user': 'pohmelie', - 'github_repo': 'aioftp', - 'github_button': True, - 'github_banner': True, + "logo": "logo.png", + "description": "ftp client/server for asyncio", + "github_user": "pohmelie", + "github_repo": "aioftp", + "github_button": True, + "github_banner": True, # 'travis_button': True, - 'pre_bg': '#FFF6E5', - 'note_bg': '#E5ECD1', - 'note_border': '#BFCF8C', - 'body_text': '#482C0A', - 'sidebar_text': '#49443E', - 'sidebar_header': '#4B4032', - 'page_width': "90%", + "pre_bg": "#FFF6E5", + "note_bg": "#E5ECD1", + "note_border": "#BFCF8C", + "body_text": "#482C0A", + "sidebar_text": "#49443E", + "sidebar_header": "#4B4032", + "page_width": "90%", } # Add any paths that contain custom themes here, relative to this directory. @@ -144,14 +144,14 @@ # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -161,127 +161,120 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Language to be used for generating the HTML full-text search index. # Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'h', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'r', 'sv', 'tr' -#html_search_language = 'en' +# html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} +# html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' +# html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. -htmlhelp_basename = 'aioftpdoc' +htmlhelp_basename = "aioftpdoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', - -# Latex figure (float) alignment -#'figure_align': 'htbp', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', + # Latex figure (float) alignment + #'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'aioftp.tex', 'aioftp Documentation', - 'pohmelie', 'manual'), + (master_doc, "aioftp.tex", "aioftp Documentation", "pohmelie", "manual"), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'aioftp', 'aioftp Documentation', - [author], 1) -] +man_pages = [(master_doc, "aioftp", "aioftp Documentation", [author], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -290,22 +283,28 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'aioftp', 'aioftp Documentation', - author, 'aioftp', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "aioftp", + "aioftp Documentation", + author, + "aioftp", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False intersphinx_mapping = { "python": ("https://docs.python.org/3", None), diff --git a/docs/index.rst b/docs/index.rst index 337c155..d7c6ab6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -71,12 +71,14 @@ pyftpdlib 1.5.2 Dependencies ------------ -- Python 3.7+ +- Python 3.11+ 0.13.0 is the last version which supports python 3.5.3+ 0.16.1 is the last version which supports python 3.6+ +0.21.4 is the last version which supports python 3.7-3.10 + License ------- diff --git a/ftpbench.py b/ftpbench.py index 309e051..b505196 100644 --- a/ftpbench.py +++ b/ftpbench.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright (C) 2007-2016 Giampaolo Rodola' . # Use of this source code is governed by MIT license that can be @@ -128,9 +128,11 @@ def hilite(string, ok=True, bold=False): def print_bench(what, value, unit=""): - s = "%s %s %-8s" % (hilite("%-50s" % what, ok=None, bold=0), - hilite("%8.2f" % value), - unit) + s = "%s %s %-8s" % ( + hilite("%-50s" % what, ok=None, bold=0), + hilite("%8.2f" % value), + unit, + ) if server_memory: s += "%s" % hilite(server_memory.pop()) print(s.strip()) @@ -205,7 +207,7 @@ def timethis(what): """ @contextlib.contextmanager def benchmark(): - timer = time.clock if sys.platform == "win32" else time.time + timer = time.perf_counter start = timer() yield stop = timer() @@ -291,13 +293,13 @@ def request_file(): stop_at += time.time() - a tot_bytes += len(chunk) - try: - while chunk: - chunk = conn.recv(BUFFER_LEN) - ftp.voidresp() - conn.close() - except (ftplib.error_temp, ftplib.error_perm): - pass + try: + while chunk: + chunk = conn.recv(BUFFER_LEN) + ftp.voidresp() + conn.close() + except (ftplib.error_temp, ftplib.error_perm): + pass else: ftp.voidcmd('TYPE I') with contextlib.closing(ftp.transfercmd("STOR " + TESTFN)) as conn: @@ -356,8 +358,11 @@ def bench_multi_connect(): def bench_multi_retr(clients): stor(clients[0]) - with timethis("%s concurrent clients (RETR %s file)" % ( - howmany, bytes2human(FILE_SIZE))): + with timethis( + "%s concurrent clients (RETR %s file)" % ( + howmany, bytes2human(FILE_SIZE), + ), + ): for ftp in clients: ftp.voidcmd('TYPE I') conn = ftp.transfercmd("RETR " + TESTFN) @@ -368,8 +373,11 @@ def bench_multi_retr(clients): ftp.voidresp() def bench_multi_stor(clients): - with timethis("%s concurrent clients (STOR %s file)" % ( - howmany, bytes2human(FILE_SIZE))): + with timethis( + "%s concurrent clients (STOR %s file)" % ( + howmany, bytes2human(FILE_SIZE), + ), + ): for ftp in clients: ftp.voidcmd('TYPE I') conn = ftp.transfercmd("STOR " + TESTFN) @@ -490,35 +498,55 @@ def main(): DEBUG USAGE = "%s -u USERNAME -p PASSWORD [-H] [-P] [-b] [-n] [-s] [-k] " \ "[-t] [-d] [-S]" % (os.path.basename(__file__)) - parser = optparse.OptionParser(usage=USAGE, - epilog=__doc__[__doc__.find('Example'):], - formatter=OptFormatter()) + parser = optparse.OptionParser( + usage=USAGE, + epilog=__doc__[__doc__.find('Example'):], + formatter=OptFormatter(), + ) parser.add_option('-u', '--user', dest='user', help='username') parser.add_option('-p', '--pass', dest='password', help='password') - parser.add_option('-H', '--host', dest='host', default=HOST, - help='hostname') - parser.add_option('-P', '--port', dest='port', default=PORT, help='port', - type=int) - parser.add_option('-b', '--benchmark', dest='benchmark', - default='transfer', - help="benchmark type ('transfer', 'download', 'upload', " - "'concurrence', 'all')") - parser.add_option('-n', '--clients', dest='clients', default=200, - type="int", - help="number of concurrent clients used by " - "'concurrence' benchmark") - parser.add_option('-s', '--filesize', dest='filesize', default="10M", - help="file size used by 'concurrence' benchmark " - "(e.g. '10M')") - parser.add_option('-k', '--pid', dest='pid', default=None, type="int", - help="the PID of the FTP server process, to track its " - "memory usage") - parser.add_option('-t', '--timeout', dest='timeout', - default=TIMEOUT, type="int", help="the socket timeout") - parser.add_option('-d', '--debug', action='store_true', dest='debug', - help="whether to print debugging info") - parser.add_option('-S', '--ssl', action='store_true', dest='ssl', - help="whether to use FTPS") + parser.add_option( + '-H', '--host', dest='host', default=HOST, + help='hostname', + ) + parser.add_option( + '-P', '--port', dest='port', default=PORT, help='port', + type=int, + ) + parser.add_option( + '-b', '--benchmark', dest='benchmark', + default='transfer', + help="benchmark type ('transfer', 'download', 'upload', " + "'concurrence', 'all')", + ) + parser.add_option( + '-n', '--clients', dest='clients', default=200, + type="int", + help="number of concurrent clients used by " + "'concurrence' benchmark", + ) + parser.add_option( + '-s', '--filesize', dest='filesize', default="10M", + help="file size used by 'concurrence' benchmark " + "(e.g. '10M')", + ) + parser.add_option( + '-k', '--pid', dest='pid', default=None, type="int", + help="the PID of the FTP server process, to track its " + "memory usage", + ) + parser.add_option( + '-t', '--timeout', dest='timeout', + default=TIMEOUT, type="int", help="the socket timeout", + ) + parser.add_option( + '-d', '--debug', action='store_true', dest='debug', + help="whether to print debugging info", + ) + parser.add_option( + '-S', '--ssl', action='store_true', dest='ssl', + help="whether to use FTPS", + ) options, args = parser.parse_args() if not options.user or not options.password: @@ -554,8 +582,11 @@ def main(): # start benchmark if SERVER_PROC is not None: register_memory() - print("(starting with %s of memory being used)" % ( - hilite(server_memory.pop()))) + print( + "(starting with %s of memory being used)" % ( + hilite(server_memory.pop()) + ), + ) if options.benchmark == 'download': stor() bench_retr() diff --git a/setup.cfg b/setup.cfg index fba792a..c257006 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,15 +16,22 @@ classifiers = [options] packages = aioftp -python_requires = >= 3.7 +python_requires = >= 3.11 [options.extras_require] socks = siosocks >= 0.2.0 -tests = +dev = + alabaster + async_timeout >= 4.0.0 + black + docutils < 0.18.0 + mypy + pre-commit pytest pytest-asyncio pytest-cov - trustme - async_timeout >= 4.0.0 + ruff siosocks + sphinx + trustme diff --git a/tests/conftest.py b/tests/conftest.py index 4c1b743..6746e31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,21 @@ -import ssl +import asyncio import collections import contextlib -import tempfile -import asyncio -import math -import time import functools +import math import socket +import ssl +import tempfile +import time from pathlib import Path import pytest import pytest_asyncio import trustme from async_timeout import timeout - -import aioftp from siosocks.io.asyncio import socks_server_handler +import aioftp # No ssl tests since https://bugs.python.org/issue36098 ca = trustme.CA() @@ -54,22 +53,25 @@ def _wrap_with_defaults(kwargs): @pytest.fixture(params=["127.0.0.1", "::1"]) def pair_factory(request): - class Factory: - - def __init__(self, client=None, server=None, *, - connected=True, logged=True, do_quit=True, - host=request.param, - server_factory=aioftp.Server, - client_factory=aioftp.Client): + def __init__( + self, + client=None, + server=None, + *, + connected=True, + logged=True, + do_quit=True, + host=request.param, + server_factory=aioftp.Server, + client_factory=aioftp.Client, + ): if client is None: client = Container() - self.client = client_factory(*client.args, - **_wrap_with_defaults(client.kwargs)) + self.client = client_factory(*client.args, **_wrap_with_defaults(client.kwargs)) if server is None: server = Container() - self.server = server_factory(*server.args, - **_wrap_with_defaults(server.kwargs)) + self.server = server_factory(*server.args, **_wrap_with_defaults(server.kwargs)) self.connected = connected self.logged = logged self.do_quit = do_quit @@ -90,8 +92,7 @@ async def make_client_files(self, *paths, size=None, atom=b"-"): size = aioftp.DEFAULT_BLOCK_SIZE * 3 data = atom * size for p in map(Path, paths): - await self.client.path_io.mkdir(p.parent, parents=True, - exist_ok=True) + await self.client.path_io.mkdir(p.parent, parents=True, exist_ok=True) async with self.client.path_io.open(p, mode="wb") as f: await f.write(data) @@ -119,8 +120,7 @@ async def __aenter__(self): await self.timeout.__aenter__() await self.server.start(host=self.host) if self.connected: - await self.client.connect(self.server.server_host, - self.server.server_port) + await self.client.connect(self.server.server_host, self.server.server_port) if self.logged: await self.client.login() return self @@ -145,11 +145,11 @@ def context(*codes): assert set(e.received_codes) == set(codes) else: raise RuntimeError("There was no exception") + return context -@pytest.fixture(params=[aioftp.MemoryPathIO, aioftp.PathIO, - aioftp.AsyncPathIO]) +@pytest.fixture(params=[aioftp.MemoryPathIO, aioftp.PathIO, aioftp.AsyncPathIO]) def path_io(request): return request.param() @@ -164,7 +164,6 @@ def temp_dir(path_io): class Sleep: - def __init__(self): self.delay = 0 self.first_sleep = None @@ -179,8 +178,7 @@ async def sleep(self, delay, result=None, **kwargs): def is_close(self, delay, *, rel_tol=0.05, abs_tol=0.5): ok = math.isclose(self.delay, delay, rel_tol=rel_tol, abs_tol=abs_tol) if not ok: - print(f"latest sleep: {self.delay}; expected delay: " - f"{delay}; rel: {rel_tol}") + print(f"latest sleep: {self.delay}; expected delay: " f"{delay}; rel: {rel_tol}") return ok @@ -192,8 +190,7 @@ def skip_sleep(monkeypatch): yield sleeper -@pytest_asyncio.fixture(params=[("127.0.0.1", socket.AF_INET), - ("::1", socket.AF_INET6)]) +@pytest_asyncio.fixture(params=[("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6)]) async def socks(request, unused_tcp_port): handler = functools.partial( socks_server_handler, @@ -204,8 +201,7 @@ async def socks(request, unused_tcp_port): Socks = collections.namedtuple("Socks", "host port server") host, family = request.param port = unused_tcp_port - server = await asyncio.start_server(handler, host=host, port=port, - family=family) + server = await asyncio.start_server(handler, host=host, port=port, family=family) yield Socks(host, port, server) server.close() await server.wait_closed() diff --git a/tests/test_abort.py b/tests/test_abort.py index 8a18778..6ae6e24 100644 --- a/tests/test_abort.py +++ b/tests/test_abort.py @@ -18,7 +18,6 @@ async def test_abort_stor(pair_factory): class SlowReadMemoryPathIO(aioftp.MemoryPathIO): - async def read(self, *args, **kwargs): await asyncio.sleep(0.01) return await super().read(*args, **kwargs) @@ -40,8 +39,7 @@ async def test_abort_retr(pair_factory, Server): @pytest.mark.asyncio -async def test_abort_retr_no_wait(pair_factory, Server, - expect_codes_in_exception): +async def test_abort_retr_no_wait(pair_factory, Server, expect_codes_in_exception): s = Server(path_io_factory=SlowReadMemoryPathIO) async with pair_factory(None, s) as pair: await pair.make_server_files("test.txt") @@ -63,14 +61,11 @@ async def test_nothing_to_abort(pair_factory): class SlowListMemoryPathIO(aioftp.MemoryPathIO): - async def is_file(self, *a, **kw): return True def list(self, *args, **kwargs): - class Lister(aioftp.AbstractAsyncLister): - async def __anext__(cls): await asyncio.sleep(0.01) return pathlib.PurePath("/test.txt") @@ -82,6 +77,7 @@ class Stat: st_size = 0 st_mtime = 0 st_ctime = 0 + return Stat diff --git a/tests/test_client_side_socks.py b/tests/test_client_side_socks.py index f1a8b4a..85ddb6d 100644 --- a/tests/test_client_side_socks.py +++ b/tests/test_client_side_socks.py @@ -4,16 +4,14 @@ @pytest.mark.asyncio async def test_socks_success(pair_factory, Client, socks): - client = Client(socks_host=socks.host, socks_port=socks.port, - socks_version=5, username="foo", password="bar") + client = Client(socks_host=socks.host, socks_port=socks.port, socks_version=5, username="foo", password="bar") async with pair_factory(client): pass @pytest.mark.asyncio async def test_socks_fail(pair_factory, Client, socks): - client = Client(socks_host=socks.host, socks_port=socks.port, - socks_version=5, username="bar", password="bar") + client = Client(socks_host=socks.host, socks_port=socks.port, socks_version=5, username="bar", password="bar") with pytest.raises(SocksException): async with pair_factory(client): pass diff --git a/tests/test_connection.py b/tests/test_connection.py index ac14132..42bb773 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -44,10 +44,7 @@ async def test_type_success(pair_factory, expect_codes_in_exception): async def test_custom_passive_commands(pair_factory): async with pair_factory(host="127.0.0.1") as pair: pair.client._passive_commands = None - await pair.client.get_passive_connection( - "A", - commands=["pasv", "epsv"] - ) + await pair.client.get_passive_connection("A", commands=["pasv", "epsv"]) @pytest.mark.asyncio @@ -82,8 +79,7 @@ async def test_pasv_connection_ports_not_added(pair_factory): @pytest.mark.asyncio -async def test_pasv_connection_ports(pair_factory, Server, - unused_tcp_port_factory): +async def test_pasv_connection_ports(pair_factory, Server, unused_tcp_port_factory): ports = [unused_tcp_port_factory(), unused_tcp_port_factory()] async with pair_factory(None, Server(data_ports=ports)) as pair: r, w = await pair.client.get_passive_connection() @@ -99,8 +95,7 @@ async def test_data_ports_remains_empty(pair_factory, Server): @pytest.mark.asyncio -async def test_pasv_connection_port_reused(pair_factory, Server, - unused_tcp_port): +async def test_pasv_connection_port_reused(pair_factory, Server, unused_tcp_port): s = Server(data_ports=[unused_tcp_port]) async with pair_factory(None, s) as pair: r, w = await pair.client.get_passive_connection() @@ -111,8 +106,7 @@ async def test_pasv_connection_port_reused(pair_factory, Server, await pair.client.quit() pair.client.close() assert pair.server.available_data_ports.qsize() == 1 - await pair.client.connect(pair.server.server_host, - pair.server.server_port) + await pair.client.connect(pair.server.server_host, pair.server.server_port) await pair.client.login() r, w = await pair.client.get_passive_connection() host, port, *_ = w.transport.get_extra_info("peername") @@ -121,9 +115,7 @@ async def test_pasv_connection_port_reused(pair_factory, Server, @pytest.mark.asyncio -async def test_pasv_connection_pasv_forced_response_address(pair_factory, - Server, - unused_tcp_port): +async def test_pasv_connection_pasv_forced_response_address(pair_factory, Server, unused_tcp_port): def ipv4_used(): try: ipaddress.IPv4Address(pair.host) @@ -132,9 +124,9 @@ def ipv4_used(): return False async with pair_factory( - server=Server(ipv4_pasv_forced_response_address='127.0.0.2'), + server=Server(ipv4_pasv_forced_response_address="127.0.0.2"), ) as pair: - assert pair.server.ipv4_pasv_forced_response_address == '127.0.0.2' + assert pair.server.ipv4_pasv_forced_response_address == "127.0.0.2" if ipv4_used(): # The connection fails here because the server starts to listen for @@ -145,19 +137,17 @@ def ipv4_used(): # pair.server.ipv4_pasv_forced_response_address failed to know that # the server returned correct external IP with pytest.raises(OSError): - await pair.client.get_passive_connection(commands=['pasv']) + await pair.client.get_passive_connection(commands=["pasv"]) # With epsv the connection should open as that does not use the # external IPv4 address but just tells the client the port to connect # to - await pair.client.get_passive_connection(commands=['epsv']) + await pair.client.get_passive_connection(commands=["epsv"]) @pytest.mark.parametrize("method", ["epsv", "pasv"]) @pytest.mark.asyncio -async def test_pasv_connection_no_free_port(pair_factory, Server, - expect_codes_in_exception, - method): +async def test_pasv_connection_no_free_port(pair_factory, Server, expect_codes_in_exception, method): s = Server(data_ports=[]) async with pair_factory(None, s, do_quit=False, host="127.0.0.1") as pair: assert pair.server.available_data_ports.qsize() == 0 @@ -166,8 +156,7 @@ async def test_pasv_connection_no_free_port(pair_factory, Server, @pytest.mark.asyncio -async def test_pasv_connection_busy_port(pair_factory, Server, - unused_tcp_port_factory): +async def test_pasv_connection_busy_port(pair_factory, Server, unused_tcp_port_factory): ports = [unused_tcp_port_factory(), unused_tcp_port_factory()] async with pair_factory(None, Server(data_ports=ports)) as pair: conflicting_server = await asyncio.start_server( @@ -184,9 +173,7 @@ async def test_pasv_connection_busy_port(pair_factory, Server, @pytest.mark.asyncio -async def test_pasv_connection_busy_port2(pair_factory, Server, - unused_tcp_port_factory, - expect_codes_in_exception): +async def test_pasv_connection_busy_port2(pair_factory, Server, unused_tcp_port_factory, expect_codes_in_exception): ports = [unused_tcp_port_factory()] s = Server(data_ports=ports) async with pair_factory(None, s, do_quit=False) as pair: @@ -218,8 +205,7 @@ async def test_client_session_context_manager(pair_factory): @pytest.mark.asyncio -async def test_long_login_sequence_fail(pair_factory, - expect_codes_in_exception): +async def test_long_login_sequence_fail(pair_factory, expect_codes_in_exception): class CustomServer(aioftp.Server): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -237,8 +223,7 @@ async def acct(self, connection, rest): connection.response("333") return True - factory = pair_factory(logged=False, server_factory=CustomServer, - do_quit=False) + factory = pair_factory(logged=False, server_factory=CustomServer, do_quit=False) async with factory as pair: with expect_codes_in_exception("333"): await pair.client.login() @@ -249,6 +234,7 @@ async def test_bad_sublines_seq(pair_factory, expect_codes_in_exception): class CustomServer(aioftp.Server): async def write_response(self, stream, code, lines="", list=False): import functools + lines = aioftp.wrap_with_container(lines) write = functools.partial(self.write_line, stream) *body, tail = lines @@ -260,6 +246,5 @@ async def write_response(self, stream, code, lines="", list=False): factory = pair_factory(connected=False, server_factory=CustomServer) async with factory as pair: with expect_codes_in_exception("220"): - await pair.client.connect(pair.server.server_host, - pair.server.server_port) + await pair.client.connect(pair.server.server_host, pair.server.server_port) await pair.client.login() diff --git a/tests/test_corner_cases.py b/tests/test_corner_cases.py index 5f1c974..725ff40 100644 --- a/tests/test_corner_cases.py +++ b/tests/test_corner_cases.py @@ -14,6 +14,7 @@ def __init__(self, *args, **kwargs): async def custom(*args, **kwargs): raise RuntimeError("Test error") + factory = pair_factory(server_factory=CustomServer, do_quit=False) async with factory as pair: with pytest.raises(ConnectionResetError): @@ -56,8 +57,7 @@ async def test_epsv_extra_arg(pair_factory, expect_codes_in_exception): @pytest.mark.asyncio -async def test_bad_server_path_io(pair_factory, Server, - expect_codes_in_exception): +async def test_bad_server_path_io(pair_factory, Server, expect_codes_in_exception): class BadPathIO(aioftp.MemoryPathIO): async def is_file(*a, **kw): return False diff --git a/tests/test_directory_actions.py b/tests/test_directory_actions.py index 9899aff..8d49647 100644 --- a/tests/test_directory_actions.py +++ b/tests/test_directory_actions.py @@ -49,8 +49,7 @@ async def test_change_directory(pair_factory): @pytest.mark.asyncio -async def test_change_directory_not_exist(pair_factory, - expect_codes_in_exception): +async def test_change_directory_not_exist(pair_factory, expect_codes_in_exception): async with pair_factory() as pair: with expect_codes_in_exception("550"): await pair.client.change_directory("bar") @@ -79,9 +78,7 @@ async def test_rename_non_empty_directory(pair_factory): class FakeErrorPathIO(aioftp.MemoryPathIO): - def list(self, path): - class Lister(aioftp.AbstractAsyncLister): @aioftp.pathio.universal_exception async def __anext__(self): @@ -91,8 +88,7 @@ async def __anext__(self): @pytest.mark.asyncio -async def test_exception_in_list(pair_factory, Server, - expect_codes_in_exception): +async def test_exception_in_list(pair_factory, Server, expect_codes_in_exception): s = Server(path_io_factory=FakeErrorPathIO) async with pair_factory(None, s) as pair: with expect_codes_in_exception("451"): diff --git a/tests/test_file.py b/tests/test_file.py index eb81ac6..be705a0 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -1,5 +1,5 @@ -import math import datetime as dt +import math from pathlib import PurePosixPath import pytest @@ -201,11 +201,8 @@ async def test_download_file_write_into(pair_factory): @pytest.mark.asyncio -async def test_upload_file_os_error(pair_factory, Server, - expect_codes_in_exception): - +async def test_upload_file_os_error(pair_factory, Server, expect_codes_in_exception): class OsErrorPathIO(aioftp.MemoryPathIO): - @aioftp.pathio.universal_exception async def write(self, fout, data): raise OSError("test os error") @@ -218,8 +215,7 @@ async def write(self, fout, data): @pytest.mark.asyncio -async def test_upload_path_unreachable(pair_factory, - expect_codes_in_exception): +async def test_upload_path_unreachable(pair_factory, expect_codes_in_exception): async with pair_factory() as pair: with expect_codes_in_exception("550"): async with pair.client.upload_stream("foo/bar/foo") as stream: @@ -244,5 +240,28 @@ async def test_stat_mlst(pair_factory): assert info["type"] == "file" for fact in ("modify", "create"): received = dt.datetime.strptime(info[fact], "%Y%m%d%H%M%S") - assert math.isclose(now.timestamp(), received.timestamp(), - abs_tol=10) + assert math.isclose(now.timestamp(), received.timestamp(), abs_tol=10) + + +@pytest.mark.asyncio +async def test_size(pair_factory, expect_codes_in_exception): + async with pair_factory() as pair: + await pair.make_server_files("foo/bar", size=5000) + size = await pair.client.size("foo/bar") + assert size == 5000 + await pair.client.remove_file("foo/bar") + assert await pair.server_paths_exists("foo/bar") is False + with expect_codes_in_exception("550"): + await pair.client.size("foo/bar") + assert await pair.client.is_dir("foo") + with expect_codes_in_exception("550"): + await pair.client.size("foo") + + +@pytest.mark.asyncio +async def test_size_ascii_mode(pair_factory, expect_codes_in_exception): + async with pair_factory() as pair: + await pair.make_server_files("foo", size=5000) + await pair.client.get_passive_connection(conn_type="A") + with expect_codes_in_exception("550"): + await pair.client.size("foo") diff --git a/tests/test_list_fallback.py b/tests/test_list_fallback.py index 0fde577..a1b057b 100644 --- a/tests/test_list_fallback.py +++ b/tests/test_list_fallback.py @@ -1,6 +1,6 @@ +import contextlib import pathlib import textwrap -import contextlib import pytest @@ -53,7 +53,8 @@ async def test_client_list_override_invalid_raw_command(pair_factory): def test_client_list_windows(): - test_str = textwrap.dedent("""\ + test_str = textwrap.dedent( + """\ 11/4/2018 9:09 PM . 8/10/2018 1:02 PM .. 9/23/2018 2:16 PM bin @@ -70,7 +71,8 @@ def test_client_list_windows(): 10/29/2018 10:55 AM 219 win7.sh 6 files 75,978,506,648 bytes 3 directories 22,198,362,112 bytes free - """) + """, + ) test_str = test_str.strip().split("\n") entities = {} parse = aioftp.Client(encoding="utf-8").parse_list_line_windows @@ -79,8 +81,7 @@ def test_client_list_windows(): path, stat = parse(x.encode("utf-8")) entities[path] = stat dirs = ["bin", "Desktop", "dow", "Downloads", "msc", "opt"] - files = ["win10.img", "win10.iso", "win10.sh", "win7.img", - "win7.iso", "win7.sh"] + files = ["win10.img", "win10.iso", "win10.sh", "win7.img", "win7.iso", "win7.sh"] assert len(entities) == len(dirs + files) for d in dirs: p = pathlib.PurePosixPath(d) @@ -100,10 +101,12 @@ async def test_client_list_override_with_custom(pair_factory, Client): def parser(b): import pickle + return pickle.loads(bytes.fromhex(b.decode().rstrip("\r\n"))) async def builder(_, path): import pickle + return pickle.dumps((path, meta)).hex() async with pair_factory(Client(parse_list_line_custom=parser)) as pair: @@ -121,12 +124,14 @@ async def builder(_, path): async def test_client_list_override_with_custom_last(pair_factory, Client): meta = {"type": "file", "works": True} - def parser(b): + def parser(b: bytes): import pickle + return pickle.loads(bytes.fromhex(b.decode().rstrip("\r\n"))) async def builder(_, path): import pickle + return pickle.dumps((path, meta)).hex() client = Client( diff --git a/tests/test_login.py b/tests/test_login.py index 08afb10..4411d3d 100644 --- a/tests/test_login.py +++ b/tests/test_login.py @@ -37,8 +37,7 @@ async def test_login_with_login_and_password(pair_factory, Server): @pytest.mark.asyncio -async def test_login_with_login_and_password_no_such_user( - pair_factory, Server, expect_codes_in_exception): +async def test_login_with_login_and_password_no_such_user(pair_factory, Server, expect_codes_in_exception): s = Server([aioftp.User("foo", "bar")]) async with pair_factory(None, s, logged=False) as pair: with expect_codes_in_exception("530"): @@ -46,8 +45,7 @@ async def test_login_with_login_and_password_no_such_user( @pytest.mark.asyncio -async def test_login_with_login_and_password_bad_password( - pair_factory, Server, expect_codes_in_exception): +async def test_login_with_login_and_password_bad_password(pair_factory, Server, expect_codes_in_exception): s = Server([aioftp.User("foo", "bar")]) async with pair_factory(None, s, logged=False) as pair: with expect_codes_in_exception("530"): @@ -55,8 +53,7 @@ async def test_login_with_login_and_password_bad_password( @pytest.mark.asyncio -async def test_pass_after_login(pair_factory, Server, - expect_codes_in_exception): +async def test_pass_after_login(pair_factory, Server, expect_codes_in_exception): s = Server([aioftp.User("foo", "bar")]) async with pair_factory(None, s, logged=False) as pair: await pair.client.login("foo", "bar") diff --git a/tests/test_maximum_connections.py b/tests/test_maximum_connections.py index b23548b..fa0dbc3 100644 --- a/tests/test_maximum_connections.py +++ b/tests/test_maximum_connections.py @@ -7,8 +7,7 @@ @pytest.mark.asyncio async def test_multiply_connections_no_limits(pair_factory): - Client = functools.partial(aioftp.Client, - path_io_factory=aioftp.MemoryPathIO) + Client = functools.partial(aioftp.Client, path_io_factory=aioftp.MemoryPathIO) async with pair_factory() as pair: s = pair.server clients = [Client() for _ in range(4)] @@ -20,10 +19,8 @@ async def test_multiply_connections_no_limits(pair_factory): @pytest.mark.asyncio -async def test_multiply_connections_limited_error(pair_factory, Server, - expect_codes_in_exception): - Client = functools.partial(aioftp.Client, - path_io_factory=aioftp.MemoryPathIO) +async def test_multiply_connections_limited_error(pair_factory, Server, expect_codes_in_exception): + Client = functools.partial(aioftp.Client, path_io_factory=aioftp.MemoryPathIO) s = Server(maximum_connections=4) async with pair_factory(None, s) as pair: s = pair.server @@ -46,10 +43,8 @@ async def test_multiply_user_commands(pair_factory, Server): @pytest.mark.asyncio -async def test_multiply_connections_with_user_limited_error( - pair_factory, Server, expect_codes_in_exception): - Client = functools.partial(aioftp.Client, - path_io_factory=aioftp.MemoryPathIO) +async def test_multiply_connections_with_user_limited_error(pair_factory, Server, expect_codes_in_exception): + Client = functools.partial(aioftp.Client, path_io_factory=aioftp.MemoryPathIO) s = Server([aioftp.User("foo", maximum_connections=4)]) async with pair_factory(None, s, connected=False) as pair: s = pair.server @@ -65,10 +60,8 @@ async def test_multiply_connections_with_user_limited_error( @pytest.mark.asyncio -async def test_multiply_connections_relogin_balanced( - pair_factory, Server, expect_codes_in_exception): - Client = functools.partial(aioftp.Client, - path_io_factory=aioftp.MemoryPathIO) +async def test_multiply_connections_relogin_balanced(pair_factory, Server, expect_codes_in_exception): + Client = functools.partial(aioftp.Client, path_io_factory=aioftp.MemoryPathIO) s = Server(maximum_connections=4) async with pair_factory(None, s, connected=False) as pair: s = pair.server diff --git a/tests/test_passive.py b/tests/test_passive.py index f52de3a..b4a757b 100644 --- a/tests/test_passive.py +++ b/tests/test_passive.py @@ -14,8 +14,7 @@ async def test_client_fallback_to_pasv_at_list(pair_factory): @pytest.mark.asyncio -async def test_client_fail_fallback_to_pasv_at_list(pair_factory, - expect_codes_in_exception): +async def test_client_fail_fallback_to_pasv_at_list(pair_factory, expect_codes_in_exception): async with pair_factory(host="127.0.0.1") as pair: pair.server.commands_mapping["epsv"] = not_implemented with expect_codes_in_exception("502"): diff --git a/tests/test_pathio.py b/tests/test_pathio.py index a44adc0..913a2f3 100644 --- a/tests/test_pathio.py +++ b/tests/test_pathio.py @@ -90,6 +90,19 @@ async def test_list(path_io, temp_dir): assert set(paths) == {d, f} +@pytest.mark.asyncio +async def test_size(path_io, temp_dir): + p = temp_dir / "foo" + async with path_io.open(p, mode="wb"): + pass + assert await path_io.exists(p) + size = await path_io.size(p) + assert size >= 0 + await path_io.unlink(p) + with universal_exception_reason(FileNotFoundError): + await path_io.size(p) + + @pytest.mark.asyncio async def test_stat(path_io, temp_dir): stat = await path_io.stat(temp_dir) diff --git a/tests/test_permissions.py b/tests/test_permissions.py index ec143b6..4db9f40 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -4,11 +4,8 @@ @pytest.mark.asyncio -async def test_permission_denied(pair_factory, Server, - expect_codes_in_exception): - s = Server([ - aioftp.User(permissions=[aioftp.Permission(writable=False)]) - ]) +async def test_permission_denied(pair_factory, Server, expect_codes_in_exception): + s = Server([aioftp.User(permissions=[aioftp.Permission(writable=False)])]) async with pair_factory(None, s) as pair: with expect_codes_in_exception("550"): await pair.client.make_directory("foo") @@ -16,14 +13,16 @@ async def test_permission_denied(pair_factory, Server, @pytest.mark.asyncio async def test_permission_overriden(pair_factory, Server): - s = Server([ - aioftp.User( - permissions=[ - aioftp.Permission("/", writable=False), - aioftp.Permission("/foo"), - ] - ) - ]) + s = Server( + [ + aioftp.User( + permissions=[ + aioftp.Permission("/", writable=False), + aioftp.Permission("/foo"), + ], + ), + ], + ) async with pair_factory(None, s) as pair: await pair.client.make_directory("foo") await pair.client.remove_directory("foo") diff --git a/tests/test_restart.py b/tests/test_restart.py index 2fe3794..45fd85d 100644 --- a/tests/test_restart.py +++ b/tests/test_restart.py @@ -20,8 +20,7 @@ async def test_restart_stor_appe(pair_factory, offset, method): atom = b"foobar" name = "foo.txt" insert = b"123" - expect = atom[:offset] + b"\x00" * (offset - len(atom)) + insert + \ - atom[offset + len(insert):] + expect = atom[:offset] + b"\x00" * (offset - len(atom)) + insert + atom[offset + len(insert) :] await pair.make_server_files(name, size=1, atom=atom) stream_factory = getattr(pair.client, method) async with stream_factory(name, offset=offset) as stream: diff --git a/tests/test_simple_functions.py b/tests/test_simple_functions.py index 5fbc312..8b98296 100644 --- a/tests/test_simple_functions.py +++ b/tests/test_simple_functions.py @@ -1,7 +1,7 @@ -import pathlib import asyncio import datetime import itertools +import pathlib import pytest @@ -64,47 +64,24 @@ def _c_locale_time(d, format="%b %d %H:%M"): def test_parse_ls_date_of_leap_year(): def date_to_p(d): return d.strftime("%Y%m%d%H%M00") + p = aioftp.Client.parse_ls_date # Leap year date to test d = datetime.datetime(year=2000, month=2, day=29) current_and_expected_dates = ( # 2016 (leap) - ( - datetime.datetime(year=2016, month=2, day=29), - datetime.datetime(year=2016, month=2, day=29) - ), + (datetime.datetime(year=2016, month=2, day=29), datetime.datetime(year=2016, month=2, day=29)), # 2017 - ( - datetime.datetime(year=2017, month=2, day=28), - datetime.datetime(year=2016, month=2, day=29) - ), - ( - datetime.datetime(year=2017, month=3, day=1), - datetime.datetime(year=2016, month=2, day=29) - ), + (datetime.datetime(year=2017, month=2, day=28), datetime.datetime(year=2016, month=2, day=29)), + (datetime.datetime(year=2017, month=3, day=1), datetime.datetime(year=2016, month=2, day=29)), # 2018 - ( - datetime.datetime(year=2018, month=2, day=28), - datetime.datetime(year=2016, month=2, day=29) - ), - ( - datetime.datetime(year=2018, month=3, day=1), - datetime.datetime(year=2020, month=2, day=29) - ), + (datetime.datetime(year=2018, month=2, day=28), datetime.datetime(year=2016, month=2, day=29)), + (datetime.datetime(year=2018, month=3, day=1), datetime.datetime(year=2020, month=2, day=29)), # 2019 - ( - datetime.datetime(year=2019, month=2, day=28), - datetime.datetime(year=2020, month=2, day=29) - ), - ( - datetime.datetime(year=2019, month=3, day=1), - datetime.datetime(year=2020, month=2, day=29) - ), + (datetime.datetime(year=2019, month=2, day=28), datetime.datetime(year=2020, month=2, day=29)), + (datetime.datetime(year=2019, month=3, day=1), datetime.datetime(year=2020, month=2, day=29)), # 2020 (leap) - ( - datetime.datetime(year=2020, month=2, day=29), - datetime.datetime(year=2020, month=2, day=29) - ), + (datetime.datetime(year=2020, month=2, day=29), datetime.datetime(year=2020, month=2, day=29)), ) for now, expected in current_and_expected_dates: assert p(_c_locale_time(d), now=now) == date_to_p(expected) @@ -113,6 +90,7 @@ def date_to_p(d): def test_parse_ls_date_not_older_than_6_month_format(): def date_to_p(d): return d.strftime("%Y%m%d%H%M00") + p = aioftp.Client.parse_ls_date dates = ( datetime.datetime(year=2002, month=1, day=1), @@ -128,6 +106,7 @@ def date_to_p(d): def test_parse_ls_date_older_than_6_month_format(): def date_to_p(d): return d.strftime("%Y%m%d%H%M00") + p = aioftp.Client.parse_ls_date dates = ( datetime.datetime(year=2002, month=1, day=1), @@ -147,6 +126,7 @@ def date_to_p(d): def test_parse_ls_date_short(): def date_to_p(d): return d.strftime("%Y%m%d%H%M00") + p = aioftp.Client.parse_ls_date dates = ( datetime.datetime(year=2002, month=1, day=1), @@ -170,7 +150,7 @@ def test_parse_list_line_unix(): ], "unknown": [ "Erw-rw-r-- 1 poh poh 6595 Feb 27 04:14 history.rst", - ] + ], } p = aioftp.Client(encoding="utf-8").parse_list_line_unix for t, stack in lines.items(): diff --git a/tests/test_throttle.py b/tests/test_throttle.py index dcc262b..f90b4b3 100644 --- a/tests/test_throttle.py +++ b/tests/test_throttle.py @@ -20,8 +20,7 @@ async def test_patched_sleep(skip_sleep): @pytest.mark.parametrize("type", ["read", "write"]) @pytest.mark.parametrize("direction", ["download", "upload"]) @pytest.mark.asyncio -async def test_client_side_throttle(pair_factory, skip_sleep, times, type, - direction): +async def test_client_side_throttle(pair_factory, skip_sleep, times, type, direction): async with pair_factory() as pair: await pair.make_server_files("foo", size=SIZE) await pair.make_client_files("foo", size=SIZE) @@ -37,20 +36,24 @@ async def test_client_side_throttle(pair_factory, skip_sleep, times, type, @pytest.mark.parametrize("users", [1, 2, 3]) @pytest.mark.parametrize("throttle_direction", ["read", "write"]) @pytest.mark.parametrize("data_direction", ["download", "upload"]) -@pytest.mark.parametrize("throttle_level", ["throttle", - "throttle_per_connection"]) +@pytest.mark.parametrize("throttle_level", ["throttle", "throttle_per_connection"]) @pytest.mark.asyncio -async def test_server_side_throttle(pair_factory, skip_sleep, times, users, - throttle_direction, data_direction, - throttle_level): +async def test_server_side_throttle( + pair_factory, + skip_sleep, + times, + users, + throttle_direction, + data_direction, + throttle_level, +): async with pair_factory() as pair: names = [] for i in range(users): name = f"foo{i}" names.append(name) await pair.make_server_files(name, size=SIZE) - throttle = reduce(getattr, [throttle_level, throttle_direction], - pair.server) + throttle = reduce(getattr, [throttle_level, throttle_direction], pair.server) throttle.limit = SIZE / times clients = [] for name in names: