# test.py 8.43 KB
import os
import operator
import sys
import contextlib
import itertools
from distutils import log
from distutils.errors import DistutilsError, DistutilsOptionError
from unittest import TestLoader

from setuptools.extern import six
from setuptools.extern.six.moves import map, filter

from pkg_resources import (resource_listdir, resource_exists, normalize_path,
                           working_set, _namespace_packages,
                           add_activation_listener, require, EntryPoint)
from setuptools import Command
from setuptools.py31compat import unittest_main


class ScanningLoader(TestLoader):
    """Test loader that also walks packages and honours ``additional_tests``."""

    def loadTestsFromModule(self, module, pattern=None):
        """Collect every test reachable from *module*.

        The stock ``TestLoader`` tests found in *module* are always included.
        In addition:

        * if *module* has an ``additional_tests`` attribute, it is called and
          its return value is included;
        * if *module* is a package (it has ``__path__``), every ``*.py``
          submodule other than ``__init__.py`` and every subdirectory holding
          an ``__init__.py`` is loaded by dotted name and included.

        ``pattern`` is accepted for signature compatibility and is ignored.

        If only one suite was collected it is returned directly rather than
        being wrapped in another suite.
        """
        found = [TestLoader.loadTestsFromModule(self, module)]

        if hasattr(module, "additional_tests"):
            found.append(module.additional_tests())

        if hasattr(module, '__path__'):
            base = module.__name__
            for entry in resource_listdir(base, ''):
                if entry.endswith('.py') and entry != '__init__.py':
                    child = base + '.' + entry[:-3]
                elif resource_exists(base, entry + '/__init__.py'):
                    # A sub-package directory.
                    child = base + '.' + entry
                else:
                    continue
                found.append(self.loadTestsFromName(child))

        # A single suite needs no extra level of nesting.
        if len(found) == 1:
            return found[0]
        return self.suiteClass(found)


# adapted from jaraco.classes.properties:NonDataProperty
class NonDataProperty(object):
    """Computed attribute implemented as a *non-data* descriptor.

    Only ``__get__`` is defined, so a value stored under the same name in an
    instance's ``__dict__`` takes precedence over the getter (unlike the
    built-in ``property``). Looked up on the class, the descriptor itself is
    returned.
    """

    def __init__(self, fget):
        # Callable invoked with the instance whenever the attribute is read.
        self.fget = fget

    def __get__(self, obj, objtype=None):
        return self if obj is None else self.fget(obj)


class test(Command):
    """Command to run unit tests after in-place build"""

    description = "run unit tests after in-place build"

    user_options = [
        ('test-module=', 'm', "Run 'test_suite' in specified module"),
        ('test-suite=', 's',
         "Test suite to run (e.g. 'some_module.test_suite')"),
        ('test-runner=', 'r', "Test runner to use"),
    ]

    def initialize_options(self):
        # Defaults; finalize_options() fills them in from the distribution.
        self.test_suite = None
        self.test_module = None
        self.test_loader = None
        self.test_runner = None

    def finalize_options(self):
        """Resolve which suite, loader and runner to use.

        Raises DistutilsOptionError if both a test module and a test suite
        were given.
        """
        if self.test_suite and self.test_module:
            msg = "You may specify a module or a suite, but not both"
            raise DistutilsOptionError(msg)

        if self.test_suite is None:
            if self.test_module is None:
                self.test_suite = self.distribution.test_suite
            else:
                self.test_suite = self.test_module + ".test_suite"

        if self.test_loader is None:
            self.test_loader = getattr(self.distribution, 'test_loader', None)
        if self.test_loader is None:
            self.test_loader = "setuptools.command.test:ScanningLoader"
        if self.test_runner is None:
            self.test_runner = getattr(self.distribution, 'test_runner', None)

    @NonDataProperty
    def test_args(self):
        """Arguments handed to unittest after the program name.

        Because this is a non-data property, assigning ``test_args`` on an
        instance overrides the computed value.
        """
        return list(self._test_args())

    def _test_args(self):
        # '--verbose' when the command is verbose, then the suite name if any.
        if self.verbose:
            yield '--verbose'
        if self.test_suite:
            yield self.test_suite

    def with_project_on_sys_path(self, func):
        """
        Backward compatibility for project_on_sys_path context.
        """
        with self.project_on_sys_path():
            func()

    @contextlib.contextmanager
    def project_on_sys_path(self, include_dists=()):
        """Build the project and make it importable inside the context.

        With ``use_2to3`` on Python 3 the sources are built out of place
        (build_py, egg_info into the build_py output dir, build_ext);
        otherwise egg_info runs and extensions are built in place. The
        egg_info base directory is then put at the front of ``sys.path`` and
        of PYTHONPATH, and the project distribution is required from a
        freshly re-initialised working set. ``sys.path``, ``sys.modules`` and
        the working set are restored on exit.

        ``include_dists`` is accepted for backward compatibility and is not
        used.
        """
        with_2to3 = six.PY3 and getattr(self.distribution, 'use_2to3', False)

        if with_2to3:
            # If we run 2to3 we can not do this inplace:

            # Ensure metadata is up-to-date
            self.reinitialize_command('build_py', inplace=0)
            self.run_command('build_py')
            bpy_cmd = self.get_finalized_command("build_py")
            build_path = normalize_path(bpy_cmd.build_lib)

            # Build extensions
            self.reinitialize_command('egg_info', egg_base=build_path)
            self.run_command('egg_info')

            self.reinitialize_command('build_ext', inplace=0)
            self.run_command('build_ext')
        else:
            # Without 2to3 inplace works fine:
            self.run_command('egg_info')

            # Build extensions in-place
            self.reinitialize_command('build_ext', inplace=1)
            self.run_command('build_ext')

        ei_cmd = self.get_finalized_command("egg_info")

        # Snapshot interpreter state so the test run cannot leak into it.
        old_path = sys.path[:]
        old_modules = sys.modules.copy()

        try:
            project_path = normalize_path(ei_cmd.egg_base)
            sys.path.insert(0, project_path)
            working_set.__init__()
            add_activation_listener(lambda dist: dist.activate())
            require('%s==%s' % (ei_cmd.egg_name, ei_cmd.egg_version))
            with self.paths_on_pythonpath([project_path]):
                yield
        finally:
            sys.path[:] = old_path
            sys.modules.clear()
            sys.modules.update(old_modules)
            working_set.__init__()

    @staticmethod
    @contextlib.contextmanager
    def paths_on_pythonpath(paths):
        """
        Add the indicated paths to the head of the PYTHONPATH environment
        variable so that subprocesses will also see the packages at
        these paths.

        Do this in a context that restores the value on exit.
        """
        # Sentinel distinguishes "PYTHONPATH unset" from "set to ''".
        nothing = object()
        orig_pythonpath = os.environ.get('PYTHONPATH', nothing)
        current_pythonpath = os.environ.get('PYTHONPATH', '')
        try:
            prefix = os.pathsep.join(paths)
            # Skip empty parts so no stray separator ends up in the value.
            to_join = filter(None, [prefix, current_pythonpath])
            new_path = os.pathsep.join(to_join)
            if new_path:
                os.environ['PYTHONPATH'] = new_path
            yield
        finally:
            if orig_pythonpath is nothing:
                os.environ.pop('PYTHONPATH', None)
            else:
                os.environ['PYTHONPATH'] = orig_pythonpath

    @staticmethod
    def install_dists(dist):
        """
        Fetch the ``install_requires`` and ``tests_require`` of *dist* via
        ``dist.fetch_build_eggs`` and return an iterable chaining the
        distributions obtained for each.
        """
        ir_d = dist.fetch_build_eggs(dist.install_requires or [])
        tr_d = dist.fetch_build_eggs(dist.tests_require or [])
        return itertools.chain(ir_d, tr_d)

    def run(self):
        """Fetch requirements, then run the tests with the project on path."""
        installed_dists = self.install_dists(self.distribution)

        cmd = ' '.join(self._argv)
        if self.dry_run:
            self.announce('skipping "%s" (dry run)' % cmd)
            return

        self.announce('running "%s"' % cmd)

        # Expose the fetched requirement eggs to subprocesses as well.
        paths = map(operator.attrgetter('location'), installed_dists)
        with self.paths_on_pythonpath(paths):
            with self.project_on_sys_path():
                self.run_tests()

    def run_tests(self):
        """Run the configured tests through unittest.

        Raises DistutilsError if the test run is not successful, so the
        command (and ``setup.py test``) reports the failure.
        """
        # Purge modules under test from sys.modules. The test loader will
        # re-import them from the build location. Required when 2to3 is used
        # with namespace packages. test_suite may be None (unittest then
        # discovers tests), in which case there is nothing to purge.
        if (self.test_suite and six.PY3
                and getattr(self.distribution, 'use_2to3', False)):
            module = self.test_suite.split('.')[0]
            if module in _namespace_packages:
                prefix = module + '.'
                # Snapshot the keys before deleting from sys.modules.
                del_modules = [
                    name for name in list(sys.modules)
                    if name == module or name.startswith(prefix)
                ]
                for name in del_modules:
                    del sys.modules[name]

        exit_kwarg = {} if sys.version_info < (2, 7) else {"exit": False}
        program = unittest_main(
            None, None, self._argv,
            testLoader=self._resolve_as_ep(self.test_loader),
            testRunner=self._resolve_as_ep(self.test_runner),
            **exit_kwarg
        )
        # With exit=False unittest returns instead of calling sys.exit(), so
        # a failing run must be surfaced here or the command would succeed.
        if not program.result.wasSuccessful():
            msg = 'Test failed: %s' % program.result
            self.announce(msg, log.ERROR)
            raise DistutilsError(msg)

    @property
    def _argv(self):
        # Full argv for unittest: a program name followed by test_args.
        return ['unittest'] + self.test_args

    @staticmethod
    def _resolve_as_ep(val):
        """
        Load the attribute named by *val* as if it were specified as an
        entry point (e.g. ``'package.module:attr'``), call it with no
        arguments and return the result. Returns None if *val* is None.
        """
        if val is None:
            return
        parsed = EntryPoint.parse("x=" + val)
        return parsed.resolve()()