Apply typing to all of pre-commit-hooks

This commit is contained in:
Anthony Sottile 2019-01-31 19:19:10 -08:00
parent 63cc3414e9
commit 030bfac7e4
54 changed files with 401 additions and 264 deletions

9
.gitignore vendored
View File

@ -1,16 +1,11 @@
*.egg-info
*.iml
*.py[co]
.*.sw[a-z]
.pytest_cache
.coverage
.idea
.project
.pydevproject
.tox
.venv.touch
/.mypy_cache
/.pytest_cache
/venv*
coverage-html
dist
# SublimeText project/workspace files
*.sublime-*

View File

@ -27,7 +27,7 @@ repos:
rev: v1.3.5
hooks:
- id: reorder-python-imports
language_version: python2.7
language_version: python3
- repo: https://github.com/asottile/pyupgrade
rev: v1.11.1
hooks:
@ -36,3 +36,8 @@ repos:
rev: v0.7.1
hooks:
- id: add-trailing-comma
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.660
hooks:
- id: mypy
language_version: python3

View File

@ -1,3 +1,4 @@
dist: xenial
language: python
matrix:
include: # These should match the tox env list
@ -6,9 +7,8 @@ matrix:
python: 3.6
- env: TOXENV=py37
python: 3.7
dist: xenial
- env: TOXENV=pypy
python: pypy-5.7.1
python: pypy2.7-5.10.0
install: pip install coveralls tox
script: tox
before_install:

View File

@ -4,7 +4,9 @@ import io
import os.path
import shutil
import tarfile
from urllib.request import urlopen
import urllib.request
from typing import cast
from typing import IO
DOWNLOAD_PATH = (
'https://github.com/github/git-lfs/releases/download/'
@ -15,7 +17,7 @@ DEST_PATH = '/tmp/git-lfs/git-lfs'
DEST_DIR = os.path.dirname(DEST_PATH)
def main():
def main(): # type: () -> int
if (
os.path.exists(DEST_PATH) and
os.path.isfile(DEST_PATH) and
@ -27,12 +29,13 @@ def main():
shutil.rmtree(DEST_DIR, ignore_errors=True)
os.makedirs(DEST_DIR, exist_ok=True)
contents = io.BytesIO(urlopen(DOWNLOAD_PATH).read())
contents = io.BytesIO(urllib.request.urlopen(DOWNLOAD_PATH).read())
with tarfile.open(fileobj=contents) as tar:
with tar.extractfile(PATH_IN_TAR) as src_file:
with cast(IO[bytes], tar.extractfile(PATH_IN_TAR)) as src_file:
with open(DEST_PATH, 'wb') as dest_file:
shutil.copyfileobj(src_file, dest_file)
os.chmod(DEST_PATH, 0o755)
return 0
if __name__ == '__main__':

12
mypy.ini Normal file
View File

@ -0,0 +1,12 @@
[mypy]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
no_implicit_optional = true
[mypy-testing.*]
disallow_untyped_defs = false
[mypy-tests.*]
disallow_untyped_defs = false

View File

@ -3,7 +3,7 @@ from __future__ import print_function
from __future__ import unicode_literals
def main(argv=None):
def main(): # type: () -> int
raise SystemExit(
'autopep8-wrapper is deprecated. Instead use autopep8 directly via '
'https://github.com/pre-commit/mirrors-autopep8',

View File

@ -7,13 +7,17 @@ import argparse
import json
import math
import os
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Set
from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output
def lfs_files():
def lfs_files(): # type: () -> Set[str]
try:
# Introduced in git-lfs 2.2.0, first working in 2.2.1
lfs_ret = cmd_output('git', 'lfs', 'status', '--json')
@ -24,6 +28,7 @@ def lfs_files():
def find_large_added_files(filenames, maxkb):
# type: (Iterable[str], int) -> int
# Find all added files that are also in the list of files pre-commit tells
# us about
filenames = (added_files() & set(filenames)) - lfs_files()
@ -38,7 +43,7 @@ def find_large_added_files(filenames, maxkb):
return retv
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'filenames', nargs='*',

View File

@ -7,9 +7,11 @@ import ast
import platform
import sys
import traceback
from typing import Optional
from typing import Sequence
def check_ast(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
@ -34,4 +36,4 @@ def check_ast(argv=None):
if __name__ == '__main__':
exit(check_ast())
exit(main())

View File

@ -4,6 +4,10 @@ import argparse
import ast
import collections
import sys
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
BUILTIN_TYPES = {
@ -22,14 +26,17 @@ BuiltinTypeCall = collections.namedtuple('BuiltinTypeCall', ['name', 'line', 'co
class BuiltinTypeVisitor(ast.NodeVisitor):
def __init__(self, ignore=None, allow_dict_kwargs=True):
self.builtin_type_calls = []
# type: (Optional[Sequence[str]], bool) -> None
self.builtin_type_calls = [] # type: List[BuiltinTypeCall]
self.ignore = set(ignore) if ignore else set()
self.allow_dict_kwargs = allow_dict_kwargs
def _check_dict_call(self, node):
def _check_dict_call(self, node): # type: (ast.Call) -> bool
return self.allow_dict_kwargs and (getattr(node, 'kwargs', None) or getattr(node, 'keywords', None))
def visit_Call(self, node):
def visit_Call(self, node): # type: (ast.Call) -> None
if not isinstance(node.func, ast.Name):
# Ignore functions that are object attributes (`foo.bar()`).
# Assume that if the user calls `builtins.list()`, they know what
@ -47,6 +54,7 @@ class BuiltinTypeVisitor(ast.NodeVisitor):
def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_kwargs=True):
# type: (str, Optional[Sequence[str]], bool) -> List[BuiltinTypeCall]
with open(filename, 'rb') as f:
tree = ast.parse(f.read(), filename=filename)
visitor = BuiltinTypeVisitor(ignore=ignore, allow_dict_kwargs=allow_dict_kwargs)
@ -54,24 +62,22 @@ def check_file_for_builtin_type_constructors(filename, ignore=None, allow_dict_k
return visitor.builtin_type_calls
def parse_args(argv):
def parse_ignore(value):
return set(value.split(','))
def parse_ignore(value): # type: (str) -> Set[str]
return set(value.split(','))
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--ignore', type=parse_ignore, default=set())
allow_dict_kwargs = parser.add_mutually_exclusive_group(required=False)
allow_dict_kwargs.add_argument('--allow-dict-kwargs', action='store_true')
allow_dict_kwargs.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false')
allow_dict_kwargs.set_defaults(allow_dict_kwargs=True)
mutex = parser.add_mutually_exclusive_group(required=False)
mutex.add_argument('--allow-dict-kwargs', action='store_true')
mutex.add_argument('--no-allow-dict-kwargs', dest='allow_dict_kwargs', action='store_false')
mutex.set_defaults(allow_dict_kwargs=True)
return parser.parse_args(argv)
args = parser.parse_args(argv)
def main(argv=None):
args = parse_args(argv)
rc = 0
for filename in args.filenames:
calls = check_file_for_builtin_type_constructors(

View File

@ -3,9 +3,11 @@ from __future__ import print_function
from __future__ import unicode_literals
import argparse
from typing import Optional
from typing import Sequence
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)

View File

@ -3,16 +3,20 @@ from __future__ import print_function
from __future__ import unicode_literals
import argparse
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import Set
from pre_commit_hooks.util import added_files
from pre_commit_hooks.util import cmd_output
def lower_set(iterable):
def lower_set(iterable): # type: (Iterable[str]) -> Set[str]
return {x.lower() for x in iterable}
def find_conflicting_filenames(filenames):
def find_conflicting_filenames(filenames): # type: (Sequence[str]) -> int
repo_files = set(cmd_output('git', 'ls-files').splitlines())
relevant_files = set(filenames) | added_files()
repo_files -= relevant_files
@ -41,7 +45,7 @@ def find_conflicting_filenames(filenames):
return retv
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'filenames', nargs='*',

View File

@ -5,6 +5,8 @@ from __future__ import unicode_literals
import argparse
import io
import tokenize
from typing import Optional
from typing import Sequence
NON_CODE_TOKENS = frozenset((
@ -13,6 +15,7 @@ NON_CODE_TOKENS = frozenset((
def check_docstring_first(src, filename='<unknown>'):
# type: (str, str) -> int
"""Returns nonzero if the source has what looks like a docstring that is
not at the beginning of the source.
@ -50,7 +53,7 @@ def check_docstring_first(src, filename='<unknown>'):
return 0
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)

View File

@ -6,9 +6,11 @@ from __future__ import unicode_literals
import argparse
import pipes
import sys
from typing import Optional
from typing import Sequence
def check_has_shebang(path):
def check_has_shebang(path): # type: (str) -> int
with open(path, 'rb') as f:
first_bytes = f.read(2)
@ -27,7 +29,7 @@ def check_has_shebang(path):
return 0
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)
@ -38,3 +40,7 @@ def main(argv=None):
retv |= check_has_shebang(filename)
return retv
if __name__ == '__main__':
exit(main())

View File

@ -4,9 +4,11 @@ import argparse
import io
import json
import sys
from typing import Optional
from typing import Sequence
def check_json(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='JSON filenames to check.')
args = parser.parse_args(argv)
@ -22,4 +24,4 @@ def check_json(argv=None):
if __name__ == '__main__':
sys.exit(check_json())
sys.exit(main())

View File

@ -2,6 +2,9 @@ from __future__ import print_function
import argparse
import os.path
from typing import Optional
from typing import Sequence
CONFLICT_PATTERNS = [
b'<<<<<<< ',
@ -12,7 +15,7 @@ CONFLICT_PATTERNS = [
WARNING_MSG = 'Merge conflict string "{0}" found in {1}:{2}'
def is_in_merge():
def is_in_merge(): # type: () -> int
return (
os.path.exists(os.path.join('.git', 'MERGE_MSG')) and
(
@ -23,7 +26,7 @@ def is_in_merge():
)
def detect_merge_conflict(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument('--assume-in-merge', action='store_true')
@ -47,4 +50,4 @@ def detect_merge_conflict(argv=None):
if __name__ == '__main__':
exit(detect_merge_conflict())
exit(main())

View File

@ -4,9 +4,11 @@ from __future__ import unicode_literals
import argparse
import os.path
from typing import Optional
from typing import Sequence
def check_symlinks(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser(description='Checks for broken symlinks.')
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
@ -25,4 +27,4 @@ def check_symlinks(argv=None):
if __name__ == '__main__':
exit(check_symlinks())
exit(main())

View File

@ -5,6 +5,8 @@ from __future__ import unicode_literals
import argparse
import re
import sys
from typing import Optional
from typing import Sequence
GITHUB_NON_PERMALINK = re.compile(
@ -12,7 +14,7 @@ GITHUB_NON_PERMALINK = re.compile(
)
def _check_filename(filename):
def _check_filename(filename): # type: (str) -> int
retv = 0
with open(filename, 'rb') as f:
for i, line in enumerate(f, 1):
@ -24,7 +26,7 @@ def _check_filename(filename):
return retv
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
args = parser.parse_args(argv)

View File

@ -5,10 +5,12 @@ from __future__ import unicode_literals
import argparse
import io
import sys
import xml.sax
import xml.sax.handler
from typing import Optional
from typing import Sequence
def check_xml(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='XML filenames to check.')
args = parser.parse_args(argv)
@ -17,7 +19,7 @@ def check_xml(argv=None):
for filename in args.filenames:
try:
with io.open(filename, 'rb') as xml_file:
xml.sax.parse(xml_file, xml.sax.ContentHandler())
xml.sax.parse(xml_file, xml.sax.handler.ContentHandler())
except xml.sax.SAXException as exc:
print('{}: Failed to xml parse ({})'.format(filename, exc))
retval = 1
@ -25,4 +27,4 @@ def check_xml(argv=None):
if __name__ == '__main__':
sys.exit(check_xml())
sys.exit(main())

View File

@ -3,22 +3,26 @@ from __future__ import print_function
import argparse
import collections
import sys
from typing import Any
from typing import Generator
from typing import Optional
from typing import Sequence
import ruamel.yaml
yaml = ruamel.yaml.YAML(typ='safe')
def _exhaust(gen):
def _exhaust(gen): # type: (Generator[str, None, None]) -> None
for _ in gen:
pass
def _parse_unsafe(*args, **kwargs):
def _parse_unsafe(*args, **kwargs): # type: (*Any, **Any) -> None
_exhaust(yaml.parse(*args, **kwargs))
def _load_all(*args, **kwargs):
def _load_all(*args, **kwargs): # type: (*Any, **Any) -> None
_exhaust(yaml.load_all(*args, **kwargs))
@ -31,7 +35,7 @@ LOAD_FNS = {
}
def check_yaml(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-m', '--multi', '--allow-multiple-documents', action='store_true',
@ -63,4 +67,4 @@ def check_yaml(argv=None):
if __name__ == '__main__':
sys.exit(check_yaml())
sys.exit(main())

View File

@ -5,6 +5,9 @@ import argparse
import ast
import collections
import traceback
from typing import List
from typing import Optional
from typing import Sequence
DEBUG_STATEMENTS = {'pdb', 'ipdb', 'pudb', 'q', 'rdb'}
@ -12,21 +15,21 @@ Debug = collections.namedtuple('Debug', ('line', 'col', 'name', 'reason'))
class DebugStatementParser(ast.NodeVisitor):
def __init__(self):
self.breakpoints = []
def __init__(self): # type: () -> None
self.breakpoints = [] # type: List[Debug]
def visit_Import(self, node):
def visit_Import(self, node): # type: (ast.Import) -> None
for name in node.names:
if name.name in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, name.name, 'imported')
self.breakpoints.append(st)
def visit_ImportFrom(self, node):
def visit_ImportFrom(self, node): # type: (ast.ImportFrom) -> None
if node.module in DEBUG_STATEMENTS:
st = Debug(node.lineno, node.col_offset, node.module, 'imported')
self.breakpoints.append(st)
def visit_Call(self, node):
def visit_Call(self, node): # type: (ast.Call) -> None
"""python3.7+ breakpoint()"""
if isinstance(node.func, ast.Name) and node.func.id == 'breakpoint':
st = Debug(node.lineno, node.col_offset, node.func.id, 'called')
@ -34,7 +37,7 @@ class DebugStatementParser(ast.NodeVisitor):
self.generic_visit(node)
def check_file(filename):
def check_file(filename): # type: (str) -> int
try:
with open(filename, 'rb') as f:
ast_obj = ast.parse(f.read(), filename=filename)
@ -58,7 +61,7 @@ def check_file(filename):
return int(bool(visitor.breakpoints))
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to run')
args = parser.parse_args(argv)

View File

@ -3,11 +3,16 @@ from __future__ import unicode_literals
import argparse
import os
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from six.moves import configparser
def get_aws_credential_files_from_env():
def get_aws_credential_files_from_env(): # type: () -> Set[str]
"""Extract credential file paths from environment variables."""
files = set()
for env_var in (
@ -19,7 +24,7 @@ def get_aws_credential_files_from_env():
return files
def get_aws_secrets_from_env():
def get_aws_secrets_from_env(): # type: () -> Set[str]
"""Extract AWS secrets from environment variables."""
keys = set()
for env_var in (
@ -30,7 +35,7 @@ def get_aws_secrets_from_env():
return keys
def get_aws_secrets_from_file(credentials_file):
def get_aws_secrets_from_file(credentials_file): # type: (str) -> Set[str]
"""Extract AWS secrets from configuration files.
Read an ini-style configuration file and return a set with all found AWS
@ -62,6 +67,7 @@ def get_aws_secrets_from_file(credentials_file):
def check_file_for_aws_keys(filenames, keys):
# type: (Sequence[str], Set[str]) -> List[Dict[str, str]]
"""Check if files contain AWS secrets.
Return a list of all files containing AWS secrets and keys found, with all
@ -82,7 +88,7 @@ def check_file_for_aws_keys(filenames, keys):
return bad_files
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Filenames to run')
parser.add_argument(
@ -111,7 +117,7 @@ def main(argv=None):
# of files to to gather AWS secrets from.
credential_files |= get_aws_credential_files_from_env()
keys = set()
keys = set() # type: Set[str]
for credential_file in credential_files:
keys |= get_aws_secrets_from_file(credential_file)

View File

@ -2,6 +2,8 @@ from __future__ import print_function
import argparse
import sys
from typing import Optional
from typing import Sequence
BLACKLIST = [
b'BEGIN RSA PRIVATE KEY',
@ -15,7 +17,7 @@ BLACKLIST = [
]
def detect_private_key(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to check')
args = parser.parse_args(argv)
@ -37,4 +39,4 @@ def detect_private_key(argv=None):
if __name__ == '__main__':
sys.exit(detect_private_key())
sys.exit(main())

View File

@ -4,9 +4,12 @@ from __future__ import unicode_literals
import argparse
import os
import sys
from typing import IO
from typing import Optional
from typing import Sequence
def fix_file(file_obj):
def fix_file(file_obj): # type: (IO[bytes]) -> int
# Test for newline at end of file
# Empty files will throw IOError here
try:
@ -49,7 +52,7 @@ def fix_file(file_obj):
return 0
def end_of_file_fixer(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@ -68,4 +71,4 @@ def end_of_file_fixer(argv=None):
if __name__ == '__main__':
sys.exit(end_of_file_fixer())
sys.exit(main())

View File

@ -12,12 +12,15 @@ conflicts and keep the file nicely ordered.
from __future__ import print_function
import argparse
from typing import IO
from typing import Optional
from typing import Sequence
PASS = 0
FAIL = 1
def sort_file_contents(f):
def sort_file_contents(f): # type: (IO[bytes]) -> int
before = list(f)
after = sorted([line.strip(b'\n\r') for line in before if line.strip()])
@ -33,7 +36,7 @@ def sort_file_contents(f):
return FAIL
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='+', help='Files to sort')
args = parser.parse_args(argv)

View File

@ -4,11 +4,15 @@ from __future__ import unicode_literals
import argparse
import collections
from typing import IO
from typing import Optional
from typing import Sequence
from typing import Union
DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-\n'
def has_coding(line):
def has_coding(line): # type: (bytes) -> bool
if not line.strip():
return False
return (
@ -33,15 +37,16 @@ class ExpectedContents(collections.namedtuple(
__slots__ = ()
@property
def has_any_pragma(self):
def has_any_pragma(self): # type: () -> bool
return self.pragma_status is not False
def is_expected_pragma(self, remove):
def is_expected_pragma(self, remove): # type: (bool) -> bool
expected_pragma_status = not remove
return self.pragma_status is expected_pragma_status
def _get_expected_contents(first_line, second_line, rest, expected_pragma):
# type: (bytes, bytes, bytes, bytes) -> ExpectedContents
if first_line.startswith(b'#!'):
shebang = first_line
potential_coding = second_line
@ -51,7 +56,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma):
rest = second_line + rest
if potential_coding == expected_pragma:
pragma_status = True
pragma_status = True # type: Optional[bool]
elif has_coding(potential_coding):
pragma_status = None
else:
@ -64,6 +69,7 @@ def _get_expected_contents(first_line, second_line, rest, expected_pragma):
def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
# type: (IO[bytes], bool, bytes) -> int
expected = _get_expected_contents(
f.readline(), f.readline(), f.read(), expected_pragma,
)
@ -93,17 +99,17 @@ def fix_encoding_pragma(f, remove=False, expected_pragma=DEFAULT_PRAGMA):
return 1
def _normalize_pragma(pragma):
def _normalize_pragma(pragma): # type: (Union[bytes, str]) -> bytes
if not isinstance(pragma, bytes):
pragma = pragma.encode('UTF-8')
return pragma.rstrip() + b'\n'
def _to_disp(pragma):
def _to_disp(pragma): # type: (bytes) -> str
return pragma.decode().rstrip()
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser('Fixes the encoding pragma of python files')
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
parser.add_argument(

View File

@ -2,10 +2,13 @@ from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
from typing import Optional
from typing import Sequence
from pre_commit_hooks.util import cmd_output
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
# `argv` is ignored, pre-commit will send us a list of files that we
# don't care about
added_diff = cmd_output(

View File

@ -4,6 +4,9 @@ from __future__ import unicode_literals
import argparse
import collections
from typing import Dict
from typing import Optional
from typing import Sequence
CRLF = b'\r\n'
@ -14,7 +17,7 @@ ALL_ENDINGS = (CR, CRLF, LF)
FIX_TO_LINE_ENDING = {'cr': CR, 'crlf': CRLF, 'lf': LF}
def _fix(filename, contents, ending):
def _fix(filename, contents, ending): # type: (str, bytes, bytes) -> None
new_contents = b''.join(
line.rstrip(b'\r\n') + ending for line in contents.splitlines(True)
)
@ -22,11 +25,11 @@ def _fix(filename, contents, ending):
f.write(new_contents)
def fix_filename(filename, fix):
def fix_filename(filename, fix): # type: (str, str) -> int
with open(filename, 'rb') as f:
contents = f.read()
counts = collections.defaultdict(int)
counts = collections.defaultdict(int) # type: Dict[bytes, int]
for line in contents.splitlines(True):
for ending in ALL_ENDINGS:
@ -63,7 +66,7 @@ def fix_filename(filename, fix):
return other_endings
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-f', '--fix',

View File

@ -1,12 +1,15 @@
from __future__ import print_function
import argparse
from typing import Optional
from typing import Sequence
from typing import Set
from pre_commit_hooks.util import CalledProcessError
from pre_commit_hooks.util import cmd_output
def is_on_branch(protected):
def is_on_branch(protected): # type: (Set[str]) -> bool
try:
branch = cmd_output('git', 'symbolic-ref', 'HEAD')
except CalledProcessError:
@ -15,7 +18,7 @@ def is_on_branch(protected):
return '/'.join(chunks[2:]) in protected
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'-b', '--branch', action='append',

View File

@ -5,12 +5,20 @@ import io
import json
import sys
from collections import OrderedDict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from six import text_type
def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=[]):
def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_keys=()):
# type: (str, str, bool, bool, Sequence[str]) -> str
def pairs_first(pairs):
# type: (Sequence[Tuple[str, str]]) -> Mapping[str, str]
before = [pair for pair in pairs if pair[0] in top_keys]
before = sorted(before, key=lambda x: top_keys.index(x[0]))
after = [pair for pair in pairs if pair[0] not in top_keys]
@ -27,13 +35,13 @@ def _get_pretty_format(contents, indent, ensure_ascii=True, sort_keys=True, top_
return text_type(json_pretty) + '\n'
def _autofix(filename, new_contents):
def _autofix(filename, new_contents): # type: (str, str) -> None
print('Fixing file {}'.format(filename))
with io.open(filename, 'w', encoding='UTF-8') as f:
f.write(new_contents)
def parse_num_to_int(s):
def parse_num_to_int(s): # type: (str) -> Union[int, str]
"""Convert string numbers to int, leaving strings as is."""
try:
return int(s)
@ -41,11 +49,11 @@ def parse_num_to_int(s):
return s
def parse_topkeys(s):
def parse_topkeys(s): # type: (str) -> List[str]
return s.split(',')
def pretty_format_json(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'--autofix',
@ -117,4 +125,4 @@ def pretty_format_json(argv=None):
if __name__ == '__main__':
sys.exit(pretty_format_json())
sys.exit(main())

View File

@ -1,6 +1,10 @@
from __future__ import print_function
import argparse
from typing import IO
from typing import List
from typing import Optional
from typing import Sequence
PASS = 0
@ -9,21 +13,23 @@ FAIL = 1
class Requirement(object):
def __init__(self):
def __init__(self): # type: () -> None
super(Requirement, self).__init__()
self.value = None
self.comments = []
self.value = None # type: Optional[bytes]
self.comments = [] # type: List[bytes]
@property
def name(self):
def name(self): # type: () -> bytes
assert self.value is not None, self.value
if self.value.startswith(b'-e '):
return self.value.lower().partition(b'=')[-1]
return self.value.lower().partition(b'==')[0]
def __lt__(self, requirement):
def __lt__(self, requirement): # type: (Requirement) -> int
# \n means top of file comment, so always return True,
# otherwise just do a string comparison with value.
assert self.value is not None, self.value
if self.value == b'\n':
return True
elif requirement.value == b'\n':
@ -32,10 +38,10 @@ class Requirement(object):
return self.name < requirement.name
def fix_requirements(f):
requirements = []
def fix_requirements(f): # type: (IO[bytes]) -> int
requirements = [] # type: List[Requirement]
before = tuple(f)
after = []
after = [] # type: List[bytes]
before_string = b''.join(before)
@ -46,6 +52,7 @@ def fix_requirements(f):
for line in before:
# If the most recent requirement object has a value, then it's
# time to start building the next requirement object.
if not len(requirements) or requirements[-1].value is not None:
requirements.append(Requirement())
@ -78,6 +85,7 @@ def fix_requirements(f):
for requirement in sorted(requirements):
after.extend(requirement.comments)
assert requirement.value, requirement.value
after.append(requirement.value)
after.extend(rest)
@ -92,7 +100,7 @@ def fix_requirements(f):
return FAIL
def fix_requirements_txt(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@ -109,3 +117,7 @@ def fix_requirements_txt(argv=None):
retv |= ret_for_file
return retv
if __name__ == '__main__':
exit(main())

View File

@ -21,12 +21,15 @@ complicated YAML files.
from __future__ import print_function
import argparse
from typing import List
from typing import Optional
from typing import Sequence
QUOTES = ["'", '"']
def sort(lines):
def sort(lines): # type: (List[str]) -> List[str]
"""Sort a YAML file in alphabetical order, keeping blocks together.
:param lines: array of strings (without newlines)
@ -44,7 +47,7 @@ def sort(lines):
return new_lines
def parse_block(lines, header=False):
def parse_block(lines, header=False): # type: (List[str], bool) -> List[str]
"""Parse and return a single block, popping off the start of `lines`.
If parsing a header block, we stop after we reach a line that is not a
@ -60,7 +63,7 @@ def parse_block(lines, header=False):
return block_lines
def parse_blocks(lines):
def parse_blocks(lines): # type: (List[str]) -> List[List[str]]
"""Parse and return all possible blocks, popping off the start of `lines`.
:param lines: list of lines
@ -77,7 +80,7 @@ def parse_blocks(lines):
return blocks
def first_key(lines):
def first_key(lines): # type: (List[str]) -> str
"""Returns a string representing the sort key of a block.
The sort key is the first YAML key we encounter, ignoring comments, and
@ -95,9 +98,11 @@ def first_key(lines):
if any(line.startswith(quote) for quote in QUOTES):
return line[1:]
return line
else:
return '' # not actually reached in reality
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)

View File

@ -4,34 +4,39 @@ from __future__ import unicode_literals
import argparse
import io
import re
import tokenize
from typing import List
from typing import Optional
from typing import Sequence
START_QUOTE_RE = re.compile('^[a-zA-Z]*"')
double_quote_starts = tuple(s for s in tokenize.single_quoted if '"' in s)
def handle_match(token_text):
def handle_match(token_text): # type: (str) -> str
if '"""' in token_text or "'''" in token_text:
return token_text
for double_quote_start in double_quote_starts:
if token_text.startswith(double_quote_start):
meat = token_text[len(double_quote_start):-1]
if '"' in meat or "'" in meat:
break
return double_quote_start.replace('"', "'") + meat + "'"
return token_text
match = START_QUOTE_RE.match(token_text)
if match is not None:
meat = token_text[match.end():-1]
if '"' in meat or "'" in meat:
return token_text
else:
return match.group().replace('"', "'") + meat + "'"
else:
return token_text
def get_line_offsets_by_line_no(src):
def get_line_offsets_by_line_no(src): # type: (str) -> List[int]
# Padded so we can index with line number
offsets = [None, 0]
offsets = [-1, 0]
for line in src.splitlines():
offsets.append(offsets[-1] + len(line) + 1)
return offsets
def fix_strings(filename):
def fix_strings(filename): # type: (str) -> int
with io.open(filename, encoding='UTF-8') as f:
contents = f.read()
line_offsets = get_line_offsets_by_line_no(contents)
@ -60,7 +65,7 @@ def fix_strings(filename):
return 0
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
args = parser.parse_args(argv)
@ -74,3 +79,7 @@ def main(argv=None):
retv |= return_value
return retv
if __name__ == '__main__':
exit(main())

View File

@ -1,12 +1,14 @@
from __future__ import print_function
import argparse
import os.path
import re
import sys
from os.path import basename
from typing import Optional
from typing import Sequence
def validate_files(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*')
parser.add_argument(
@ -18,7 +20,7 @@ def validate_files(argv=None):
retcode = 0
test_name_pattern = 'test.*.py' if args.django else '.*_test.py'
for filename in args.filenames:
base = basename(filename)
base = os.path.basename(filename)
if (
not re.match(test_name_pattern, base) and
not base == '__init__.py' and
@ -35,4 +37,4 @@ def validate_files(argv=None):
if __name__ == '__main__':
sys.exit(validate_files())
sys.exit(main())

View File

@ -3,9 +3,11 @@ from __future__ import print_function
import argparse
import os
import sys
from typing import Optional
from typing import Sequence
def _fix_file(filename, is_markdown):
def _fix_file(filename, is_markdown): # type: (str, bool) -> bool
with open(filename, mode='rb') as file_processed:
lines = file_processed.readlines()
newlines = [_process_line(line, is_markdown) for line in lines]
@ -18,7 +20,7 @@ def _fix_file(filename, is_markdown):
return False
def _process_line(line, is_markdown):
def _process_line(line, is_markdown): # type: (bytes, bool) -> bytes
if line[-2:] == b'\r\n':
eol = b'\r\n'
elif line[-1:] == b'\n':
@ -31,7 +33,7 @@ def _process_line(line, is_markdown):
return line.rstrip() + eol
def main(argv=None):
def main(argv=None): # type: (Optional[Sequence[str]]) -> int
parser = argparse.ArgumentParser()
parser.add_argument(
'--no-markdown-linebreak-ext',

View File

@ -3,23 +3,25 @@ from __future__ import print_function
from __future__ import unicode_literals
import subprocess
from typing import Any
from typing import Set
class CalledProcessError(RuntimeError):
pass
def added_files():
def added_files(): # type: () -> Set[str]
return set(cmd_output(
'git', 'diff', '--staged', '--name-only', '--diff-filter=A',
).splitlines())
def cmd_output(*cmd, **kwargs):
def cmd_output(*cmd, **kwargs): # type: (*str, **Any) -> str
retcode = kwargs.pop('retcode', 0)
popen_kwargs = {'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE}
popen_kwargs.update(kwargs)
proc = subprocess.Popen(cmd, **popen_kwargs)
kwargs.setdefault('stdout', subprocess.PIPE)
kwargs.setdefault('stderr', subprocess.PIPE)
proc = subprocess.Popen(cmd, **kwargs)
stdout, stderr = proc.communicate()
stdout = stdout.decode('UTF-8')
if stderr is not None:

View File

@ -28,35 +28,36 @@ setup(
'ruamel.yaml>=0.15',
'six',
],
extras_require={':python_version<"3.5"': ['typing']},
entry_points={
'console_scripts': [
'autopep8-wrapper = pre_commit_hooks.autopep8_wrapper:main',
'check-added-large-files = pre_commit_hooks.check_added_large_files:main',
'check-ast = pre_commit_hooks.check_ast:check_ast',
'check-ast = pre_commit_hooks.check_ast:main',
'check-builtin-literals = pre_commit_hooks.check_builtin_literals:main',
'check-byte-order-marker = pre_commit_hooks.check_byte_order_marker:main',
'check-case-conflict = pre_commit_hooks.check_case_conflict:main',
'check-docstring-first = pre_commit_hooks.check_docstring_first:main',
'check-executables-have-shebangs = pre_commit_hooks.check_executables_have_shebangs:main',
'check-json = pre_commit_hooks.check_json:check_json',
'check-merge-conflict = pre_commit_hooks.check_merge_conflict:detect_merge_conflict',
'check-symlinks = pre_commit_hooks.check_symlinks:check_symlinks',
'check-json = pre_commit_hooks.check_json:main',
'check-merge-conflict = pre_commit_hooks.check_merge_conflict:main',
'check-symlinks = pre_commit_hooks.check_symlinks:main',
'check-vcs-permalinks = pre_commit_hooks.check_vcs_permalinks:main',
'check-xml = pre_commit_hooks.check_xml:check_xml',
'check-yaml = pre_commit_hooks.check_yaml:check_yaml',
'check-xml = pre_commit_hooks.check_xml:main',
'check-yaml = pre_commit_hooks.check_yaml:main',
'debug-statement-hook = pre_commit_hooks.debug_statement_hook:main',
'detect-aws-credentials = pre_commit_hooks.detect_aws_credentials:main',
'detect-private-key = pre_commit_hooks.detect_private_key:detect_private_key',
'detect-private-key = pre_commit_hooks.detect_private_key:main',
'double-quote-string-fixer = pre_commit_hooks.string_fixer:main',
'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:end_of_file_fixer',
'end-of-file-fixer = pre_commit_hooks.end_of_file_fixer:main',
'file-contents-sorter = pre_commit_hooks.file_contents_sorter:main',
'fix-encoding-pragma = pre_commit_hooks.fix_encoding_pragma:main',
'forbid-new-submodules = pre_commit_hooks.forbid_new_submodules:main',
'mixed-line-ending = pre_commit_hooks.mixed_line_ending:main',
'name-tests-test = pre_commit_hooks.tests_should_end_in_test:validate_files',
'name-tests-test = pre_commit_hooks.tests_should_end_in_test:main',
'no-commit-to-branch = pre_commit_hooks.no_commit_to_branch:main',
'pretty-format-json = pre_commit_hooks.pretty_format_json:pretty_format_json',
'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:fix_requirements_txt',
'pretty-format-json = pre_commit_hooks.pretty_format_json:main',
'requirements-txt-fixer = pre_commit_hooks.requirements_txt_fixer:main',
'sort-simple-yaml = pre_commit_hooks.sort_simple_yaml:main',
'trailing-whitespace-fixer = pre_commit_hooks.trailing_whitespace_fixer:main',
],

0
testing/resources/bad_json_latin1.nonjson Executable file → Normal file
View File

View File

@ -1,17 +0,0 @@
from six.moves import builtins
c1 = complex()
d1 = dict()
f1 = float()
i1 = int()
l1 = list()
s1 = str()
t1 = tuple()
c2 = builtins.complex()
d2 = builtins.dict()
f2 = builtins.float()
i2 = builtins.int()
l2 = builtins.list()
s2 = builtins.str()
t2 = builtins.tuple()

View File

@ -1,7 +0,0 @@
c1 = 0j
d1 = {}
f1 = 0.0
i1 = 0
l1 = []
s1 = ''
t1 = ()

View File

@ -1,15 +1,15 @@
from __future__ import absolute_import
from __future__ import unicode_literals
from pre_commit_hooks.check_ast import check_ast
from pre_commit_hooks.check_ast import main
from testing.util import get_resource_path
def test_failing_file():
ret = check_ast([get_resource_path('cannot_parse_ast.notpy')])
ret = main([get_resource_path('cannot_parse_ast.notpy')])
assert ret == 1
def test_passing_file():
ret = check_ast([__file__])
ret = main([__file__])
assert ret == 0

View File

@ -5,7 +5,35 @@ import pytest
from pre_commit_hooks.check_builtin_literals import BuiltinTypeCall
from pre_commit_hooks.check_builtin_literals import BuiltinTypeVisitor
from pre_commit_hooks.check_builtin_literals import main
from testing.util import get_resource_path
BUILTIN_CONSTRUCTORS = '''\
from six.moves import builtins
c1 = complex()
d1 = dict()
f1 = float()
i1 = int()
l1 = list()
s1 = str()
t1 = tuple()
c2 = builtins.complex()
d2 = builtins.dict()
f2 = builtins.float()
i2 = builtins.int()
l2 = builtins.list()
s2 = builtins.str()
t2 = builtins.tuple()
'''
BUILTIN_LITERALS = '''\
c1 = 0j
d1 = {}
f1 = 0.0
i1 = 0
l1 = []
s1 = ''
t1 = ()
'''
@pytest.fixture
@ -94,24 +122,26 @@ def test_dict_no_allow_kwargs_exprs(expression, calls):
def test_ignore_constructors():
visitor = BuiltinTypeVisitor(ignore=('complex', 'dict', 'float', 'int', 'list', 'str', 'tuple'))
with open(get_resource_path('builtin_constructors.py'), 'rb') as f:
visitor.visit(ast.parse(f.read(), 'builtin_constructors.py'))
visitor.visit(ast.parse(BUILTIN_CONSTRUCTORS))
assert visitor.builtin_type_calls == []
def test_failing_file():
rc = main([get_resource_path('builtin_constructors.py')])
def test_failing_file(tmpdir):
f = tmpdir.join('f.py')
f.write(BUILTIN_CONSTRUCTORS)
rc = main([f.strpath])
assert rc == 1
def test_passing_file():
rc = main([get_resource_path('builtin_literals.py')])
def test_passing_file(tmpdir):
f = tmpdir.join('f.py')
f.write(BUILTIN_LITERALS)
rc = main([f.strpath])
assert rc == 0
def test_failing_file_ignore_all():
rc = main([
'--ignore=complex,dict,float,int,list,str,tuple',
get_resource_path('builtin_constructors.py'),
])
def test_failing_file_ignore_all(tmpdir):
f = tmpdir.join('f.py')
f.write(BUILTIN_CONSTRUCTORS)
rc = main(['--ignore=complex,dict,float,int,list,str,tuple', f.strpath])
assert rc == 0

View File

@ -1,6 +1,6 @@
import pytest
from pre_commit_hooks.check_json import check_json
from pre_commit_hooks.check_json import main
from testing.util import get_resource_path
@ -11,8 +11,8 @@ from testing.util import get_resource_path
('ok_json.json', 0),
),
)
def test_check_json(capsys, filename, expected_retval):
ret = check_json([get_resource_path(filename)])
def test_main(capsys, filename, expected_retval):
ret = main([get_resource_path(filename)])
assert ret == expected_retval
if expected_retval == 1:
stdout, _ = capsys.readouterr()

View File

@ -6,7 +6,7 @@ import shutil
import pytest
from pre_commit_hooks.check_merge_conflict import detect_merge_conflict
from pre_commit_hooks.check_merge_conflict import main
from pre_commit_hooks.util import cmd_output
from testing.util import get_resource_path
@ -102,7 +102,7 @@ def repository_pending_merge(tmpdir):
@pytest.mark.usefixtures('f1_is_a_conflict_file')
def test_merge_conflicts_git():
assert detect_merge_conflict(['f1']) == 1
assert main(['f1']) == 1
@pytest.mark.parametrize(
@ -110,7 +110,7 @@ def test_merge_conflicts_git():
)
def test_merge_conflicts_failing(contents, repository_pending_merge):
repository_pending_merge.join('f2').write_binary(contents)
assert detect_merge_conflict(['f2']) == 1
assert main(['f2']) == 1
@pytest.mark.parametrize(
@ -118,22 +118,22 @@ def test_merge_conflicts_failing(contents, repository_pending_merge):
)
def test_merge_conflicts_ok(contents, f1_is_a_conflict_file):
f1_is_a_conflict_file.join('f1').write_binary(contents)
assert detect_merge_conflict(['f1']) == 0
assert main(['f1']) == 0
@pytest.mark.usefixtures('f1_is_a_conflict_file')
def test_ignores_binary_files():
shutil.copy(get_resource_path('img1.jpg'), 'f1')
assert detect_merge_conflict(['f1']) == 0
assert main(['f1']) == 0
def test_does_not_care_when_not_in_a_merge(tmpdir):
f = tmpdir.join('README.md')
f.write_binary(b'problem\n=======\n')
assert detect_merge_conflict([str(f.realpath())]) == 0
assert main([str(f.realpath())]) == 0
def test_care_when_assumed_merge(tmpdir):
f = tmpdir.join('README.md')
f.write_binary(b'problem\n=======\n')
assert detect_merge_conflict([str(f.realpath()), '--assume-in-merge']) == 1
assert main([str(f.realpath()), '--assume-in-merge']) == 1

View File

@ -2,7 +2,7 @@ import os
import pytest
from pre_commit_hooks.check_symlinks import check_symlinks
from pre_commit_hooks.check_symlinks import main
xfail_symlink = pytest.mark.xfail(os.name == 'nt', reason='No symlink support')
@ -12,12 +12,12 @@ xfail_symlink = pytest.mark.xfail(os.name == 'nt', reason='No symlink support')
@pytest.mark.parametrize(
('dest', 'expected'), (('exists', 0), ('does-not-exist', 1)),
)
def test_check_symlinks(tmpdir, dest, expected): # pragma: no cover (symlinks)
def test_main(tmpdir, dest, expected): # pragma: no cover (symlinks)
tmpdir.join('exists').ensure()
symlink = tmpdir.join('symlink')
symlink.mksymlinkto(tmpdir.join(dest))
assert check_symlinks((symlink.strpath,)) == expected
assert main((symlink.strpath,)) == expected
def test_check_symlinks_normal_file(tmpdir):
assert check_symlinks((tmpdir.join('f').ensure().strpath,)) == 0
def test_main_normal_file(tmpdir):
assert main((tmpdir.join('f').ensure().strpath,)) == 0

View File

@ -1,6 +1,6 @@
import pytest
from pre_commit_hooks.check_xml import check_xml
from pre_commit_hooks.check_xml import main
from testing.util import get_resource_path
@ -10,6 +10,6 @@ from testing.util import get_resource_path
('ok_xml.xml', 0),
),
)
def test_check_xml(filename, expected_retval):
ret = check_xml([get_resource_path(filename)])
def test_main(filename, expected_retval):
ret = main([get_resource_path(filename)])
assert ret == expected_retval

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import pytest
from pre_commit_hooks.check_yaml import check_yaml
from pre_commit_hooks.check_yaml import main
from testing.util import get_resource_path
@ -13,29 +13,29 @@ from testing.util import get_resource_path
('ok_yaml.yaml', 0),
),
)
def test_check_yaml(filename, expected_retval):
ret = check_yaml([get_resource_path(filename)])
def test_main(filename, expected_retval):
ret = main([get_resource_path(filename)])
assert ret == expected_retval
def test_check_yaml_allow_multiple_documents(tmpdir):
def test_main_allow_multiple_documents(tmpdir):
f = tmpdir.join('test.yaml')
f.write('---\nfoo\n---\nbar\n')
# should fail without the setting
assert check_yaml((f.strpath,))
assert main((f.strpath,))
# should pass when we allow multiple documents
assert not check_yaml(('--allow-multiple-documents', f.strpath))
assert not main(('--allow-multiple-documents', f.strpath))
def test_fails_even_with_allow_multiple_documents(tmpdir):
f = tmpdir.join('test.yaml')
f.write('[')
assert check_yaml(('--allow-multiple-documents', f.strpath))
assert main(('--allow-multiple-documents', f.strpath))
def test_check_yaml_unsafe(tmpdir):
def test_main_unsafe(tmpdir):
f = tmpdir.join('test.yaml')
f.write(
'some_foo: !vault |\n'
@ -43,12 +43,12 @@ def test_check_yaml_unsafe(tmpdir):
' deadbeefdeadbeefdeadbeef\n',
)
# should fail "safe" check
assert check_yaml((f.strpath,))
assert main((f.strpath,))
# should pass when we allow unsafe documents
assert not check_yaml(('--unsafe', f.strpath))
assert not main(('--unsafe', f.strpath))
def test_check_yaml_unsafe_still_fails_on_syntax_errors(tmpdir):
def test_main_unsafe_still_fails_on_syntax_errors(tmpdir):
f = tmpdir.join('test.yaml')
f.write('[')
assert check_yaml(('--unsafe', f.strpath))
assert main(('--unsafe', f.strpath))

View File

@ -1,6 +1,6 @@
import pytest
from pre_commit_hooks.detect_private_key import detect_private_key
from pre_commit_hooks.detect_private_key import main
# Input, expected return value
TESTS = (
@ -18,7 +18,7 @@ TESTS = (
@pytest.mark.parametrize(('input_s', 'expected_retval'), TESTS)
def test_detect_private_key(input_s, expected_retval, tmpdir):
def test_main(input_s, expected_retval, tmpdir):
path = tmpdir.join('file.txt')
path.write_binary(input_s)
assert detect_private_key([path.strpath]) == expected_retval
assert main([path.strpath]) == expected_retval

View File

@ -2,8 +2,8 @@ import io
import pytest
from pre_commit_hooks.end_of_file_fixer import end_of_file_fixer
from pre_commit_hooks.end_of_file_fixer import fix_file
from pre_commit_hooks.end_of_file_fixer import main
# Input, expected return value, expected output
@ -35,7 +35,7 @@ def test_integration(input_s, expected_retval, output, tmpdir):
path = tmpdir.join('file.txt')
path.write_binary(input_s)
ret = end_of_file_fixer([path.strpath])
ret = main([path.strpath])
file_output = path.read_binary()
assert file_output == output

View File

@ -11,24 +11,24 @@ from pre_commit_hooks.util import cmd_output
def test_other_branch(temp_git_dir):
with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'anotherbranch')
assert is_on_branch(('master',)) is False
assert is_on_branch({'master'}) is False
def test_multi_branch(temp_git_dir):
with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'another/branch')
assert is_on_branch(('master',)) is False
assert is_on_branch({'master'}) is False
def test_multi_branch_fail(temp_git_dir):
with temp_git_dir.as_cwd():
cmd_output('git', 'checkout', '-b', 'another/branch')
assert is_on_branch(('another/branch',)) is True
assert is_on_branch({'another/branch'}) is True
def test_master_branch(temp_git_dir):
with temp_git_dir.as_cwd():
assert is_on_branch(('master',)) is True
assert is_on_branch({'master'}) is True
def test_main_branch_call(temp_git_dir):

View File

@ -3,8 +3,8 @@ import shutil
import pytest
from six import PY2
from pre_commit_hooks.pretty_format_json import main
from pre_commit_hooks.pretty_format_json import parse_num_to_int
from pre_commit_hooks.pretty_format_json import pretty_format_json
from testing.util import get_resource_path
@ -23,8 +23,8 @@ def test_parse_num_to_int():
('pretty_formatted_json.json', 0),
),
)
def test_pretty_format_json(filename, expected_retval):
ret = pretty_format_json([get_resource_path(filename)])
def test_main(filename, expected_retval):
ret = main([get_resource_path(filename)])
assert ret == expected_retval
@ -36,8 +36,8 @@ def test_pretty_format_json(filename, expected_retval):
('pretty_formatted_json.json', 0),
),
)
def test_unsorted_pretty_format_json(filename, expected_retval):
ret = pretty_format_json(['--no-sort-keys', get_resource_path(filename)])
def test_unsorted_main(filename, expected_retval):
ret = main(['--no-sort-keys', get_resource_path(filename)])
assert ret == expected_retval
@ -51,17 +51,17 @@ def test_unsorted_pretty_format_json(filename, expected_retval):
('tab_pretty_formatted_json.json', 0),
),
)
def test_tab_pretty_format_json(filename, expected_retval): # pragma: no cover
ret = pretty_format_json(['--indent', '\t', get_resource_path(filename)])
def test_tab_main(filename, expected_retval): # pragma: no cover
ret = main(['--indent', '\t', get_resource_path(filename)])
assert ret == expected_retval
def test_non_ascii_pretty_format_json():
ret = pretty_format_json(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')])
def test_non_ascii_main():
ret = main(['--no-ensure-ascii', get_resource_path('non_ascii_pretty_formatted_json.json')])
assert ret == 0
def test_autofix_pretty_format_json(tmpdir):
def test_autofix_main(tmpdir):
srcfile = tmpdir.join('to_be_json_formatted.json')
shutil.copyfile(
get_resource_path('not_pretty_formatted_json.json'),
@ -69,30 +69,30 @@ def test_autofix_pretty_format_json(tmpdir):
)
# now launch the autofix on that file
ret = pretty_format_json(['--autofix', srcfile.strpath])
ret = main(['--autofix', srcfile.strpath])
# it should have formatted it
assert ret == 1
# file was formatted (shouldn't trigger linter again)
ret = pretty_format_json([srcfile.strpath])
ret = main([srcfile.strpath])
assert ret == 0
def test_orderfile_get_pretty_format():
ret = pretty_format_json(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')])
ret = main(['--top-keys=alist', get_resource_path('pretty_formatted_json.json')])
assert ret == 0
def test_not_orderfile_get_pretty_format():
ret = pretty_format_json(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')])
ret = main(['--top-keys=blah', get_resource_path('pretty_formatted_json.json')])
assert ret == 1
def test_top_sorted_get_pretty_format():
ret = pretty_format_json(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')])
ret = main(['--top-keys=01-alist,alist', get_resource_path('top_sorted_json.json')])
assert ret == 0
def test_badfile_pretty_format_json():
ret = pretty_format_json([get_resource_path('ok_yaml.yaml')])
def test_badfile_main():
ret = main([get_resource_path('ok_yaml.yaml')])
assert ret == 1

View File

@ -1,7 +1,7 @@
import pytest
from pre_commit_hooks.requirements_txt_fixer import FAIL
from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt
from pre_commit_hooks.requirements_txt_fixer import main
from pre_commit_hooks.requirements_txt_fixer import PASS
from pre_commit_hooks.requirements_txt_fixer import Requirement
@ -36,7 +36,7 @@ def test_integration(input_s, expected_retval, output, tmpdir):
path = tmpdir.join('file.txt')
path.write_binary(input_s)
output_retval = fix_requirements_txt([path.strpath])
output_retval = main([path.strpath])
assert path.read_binary() == output
assert output_retval == expected_retval
@ -44,7 +44,7 @@ def test_integration(input_s, expected_retval, output, tmpdir):
def test_requirement_object():
top_of_file = Requirement()
top_of_file.comments.append('#foo')
top_of_file.comments.append(b'#foo')
top_of_file.value = b'\n'
requirement_foo = Requirement()

View File

@ -110,9 +110,9 @@ def test_first_key():
lines = ['# some comment', '"a": 42', 'b: 17', '', 'c: 19']
assert first_key(lines) == 'a": 42'
# no lines
# no lines (not a real situation)
lines = []
assert first_key(lines) is None
assert first_key(lines) == ''
@pytest.mark.parametrize('bad_lines,good_lines,_', TEST_SORTS)

View File

@ -1,36 +1,36 @@
from pre_commit_hooks.tests_should_end_in_test import validate_files
from pre_commit_hooks.tests_should_end_in_test import main
def test_validate_files_all_pass():
ret = validate_files(['foo_test.py', 'bar_test.py'])
def test_main_all_pass():
ret = main(['foo_test.py', 'bar_test.py'])
assert ret == 0
def test_validate_files_one_fails():
ret = validate_files(['not_test_ending.py', 'foo_test.py'])
def test_main_one_fails():
ret = main(['not_test_ending.py', 'foo_test.py'])
assert ret == 1
def test_validate_files_django_all_pass():
ret = validate_files(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py'])
def test_main_django_all_pass():
ret = main(['--django', 'tests.py', 'test_foo.py', 'test_bar.py', 'tests/test_baz.py'])
assert ret == 0
def test_validate_files_django_one_fails():
ret = validate_files(['--django', 'not_test_ending.py', 'test_foo.py'])
def test_main_django_one_fails():
ret = main(['--django', 'not_test_ending.py', 'test_foo.py'])
assert ret == 1
def test_validate_nested_files_django_one_fails():
ret = validate_files(['--django', 'tests/not_test_ending.py', 'test_foo.py'])
ret = main(['--django', 'tests/not_test_ending.py', 'test_foo.py'])
assert ret == 1
def test_validate_files_not_django_fails():
ret = validate_files(['foo_test.py', 'bar_test.py', 'test_baz.py'])
def test_main_not_django_fails():
ret = main(['foo_test.py', 'bar_test.py', 'test_baz.py'])
assert ret == 1
def test_validate_files_django_fails():
ret = validate_files(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py'])
def test_main_django_fails():
ret = main(['--django', 'foo_test.py', 'test_bar.py', 'test_baz.py'])
assert ret == 1

View File

@ -1,6 +1,6 @@
[tox]
# These should match the travis env list
envlist = py27,py36,py37,pypy
envlist = py27,py36,py37,pypy3
[testenv]
deps = -rrequirements-dev.txt