Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dataclasses instead of pytools.Record #40

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions codepy/bpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def expose_vector_type(self, name, py_name=None):
if py_name is None:
py_name = name

from cgen import (Block, Typedef, Line, Statement, Value)
from cgen import Block, Line, Statement, Typedef, Value

self.init_body.append(
Block([
Expand Down Expand Up @@ -134,8 +134,7 @@ def generate(self):
module line-by-line.
"""

from cgen import Block, Module, Include, Line, Define, \
PrivateNamespace
from cgen import Block, Define, Include, Line, Module, PrivateNamespace

body = []

Expand Down
5 changes: 3 additions & 2 deletions codepy/cuda.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import cgen


"""Convenience interface for using CodePy with CUDA"""


Expand Down Expand Up @@ -70,8 +72,7 @@ def compile(self, host_toolchain, nvcc_toolchain,
host_code = "{}\n".format(self.boost_module.generate())
device_code = "{}\n".format(self.generate())

from codepy.jit import compile_from_string
from codepy.jit import link_extension
from codepy.jit import compile_from_string, link_extension

local_host_kwargs = kwargs.copy()
local_host_kwargs.update(host_kwargs)
Expand Down
11 changes: 6 additions & 5 deletions codepy/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
__copyright__ = "Copyright (C) 2009 Andreas Kloeckner"


from pytools import memoize
import numpy
from cgen import POD, Value, dtype_to_ctype

from pytools import memoize


class Argument:
def __init__(self, dtype, name):
Expand Down Expand Up @@ -46,11 +47,11 @@ def struct_char(self):


def get_elwise_module_descriptor(arguments, operation, name="kernel"):
from codepy.bpl import BoostPythonModule
from cgen import (
POD, Block, For, FunctionBody, FunctionDeclaration, Include, Initializer,
Line, Statement, Struct, Value)

from cgen import FunctionBody, FunctionDeclaration, \
Value, POD, Struct, For, Initializer, Include, Statement, \
Line, Block
from codepy.bpl import BoostPythonModule

S = Statement # noqa: N806

Expand Down
30 changes: 21 additions & 9 deletions codepy/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@
THE SOFTWARE.
"""

import logging
from dataclasses import dataclass
from typing import List, NamedTuple

from codepy import CompileError
from pytools import Record

import logging

logger = logging.getLogger(__name__)


def _erase_dir(dir):
from os import listdir, unlink, rmdir
from os import listdir, rmdir, unlink
from os.path import join
for name in listdir(dir):
unlink(join(dir, name))
Expand Down Expand Up @@ -211,8 +214,16 @@ class _InvalidInfoFile(RuntimeError):
pass


class _SourceInfo(Record):
pass
class _Dependency(NamedTuple):
name: str
mtime: int
md5: str


@dataclass(frozen=True)
class _SourceInfo:
dependencies: List[NamedTuple]
source_name: str


def compile_from_string(toolchain, name, source_string,
Expand Down Expand Up @@ -308,10 +319,9 @@ def get_file_md5sum(fname):
return checksum.hexdigest()

def get_dep_structure(source_paths):
deps = list(toolchain.get_dependencies(source_paths))
deps.sort()
return [(dep, os.stat(dep).st_mtime, get_file_md5sum(dep)) for dep in deps
if dep not in source_paths]
deps = toolchain.get_dependencies(source_paths)
return [_Dependency(dep, os.stat(dep).st_mtime, get_file_md5sum(dep))
for dep in sorted(deps) if dep not in source_paths]

def write_source(name):
for i, source in enumerate(source_string):
Expand Down Expand Up @@ -479,6 +489,8 @@ def link_extension(toolchain, objects, mod_name, cache_dir=None,


from pytools import MovedFunctionDeprecationWrapper # noqa: E402

from codepy.toolchain import guess_toolchain as _gtc # noqa: E402


guess_toolchain = MovedFunctionDeprecationWrapper(_gtc)
4 changes: 2 additions & 2 deletions codepy/libraries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def search_on_path(filenames):
"""Find file on system path."""
# http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52224

from os.path import exists, join, abspath
from os import pathsep, environ
from os import environ, pathsep
from os.path import abspath, exists, join

search_path = environ["PATH"]

Expand Down
112 changes: 75 additions & 37 deletions codepy/toolchain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Toolchains for Just-in-time Python extension compilation."""


__copyright__ = """
"Copyright (C) 2008,9 Andreas Kloeckner, Bryan Catanzaro
"""
Expand All @@ -25,28 +24,40 @@
THE SOFTWARE.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import AbstractSet, Any, List

from codepy import CompileError
from pytools import Record


class Toolchain(Record):
@dataclass(frozen=True)
class Toolchain(ABC):
"""Abstract base class for tools used to link dynamic Python modules."""

def __init__(self, *args, **kwargs):
if "features" not in kwargs:
kwargs["features"] = set()
Record.__init__(self, *args, **kwargs)
#: A list of directories where libraries are found.
library_dirs: List[str]
#: A list of libraries used.
libraries: List[str]
#: A list of directories from which to include header files.
include_dirs: List[str]

features: AbstractSet[str]

def copy(self, **kwargs: Any) -> "Toolchain":
from warnings import warn
warn(f"'{type(self).__name__}.copy' is deprecated. This is now a "
"dataclass and should be used with the standard 'replace'.",
DeprecationWarning, stacklevel=2)

return replace(self, **kwargs)

@abstractmethod
def get_version(self):
"""Return a string describing the exact version of the tools (compilers etc.)
involved in this toolchain.

Implemented by subclasses.
"""

raise NotImplementedError

def abi_id(self):
"""Return a picklable Python object that describes the ABI (Python version,
compiler versions, etc.) against which a Python module is compiled.
Expand Down Expand Up @@ -80,50 +91,38 @@ def add_library(self, feature, include_dirs, library_dirs, libraries):

self.libraries = libraries + self.libraries

@abstractmethod
def get_dependencies(self, source_files):
"""Return a list of header files referred to by *source_files.

Implemented by subclasses.
"""

raise NotImplementedError
"""Return a list of header files referred to by *source_files*."""

@abstractmethod
def build_extension(self, ext_file, source_files, debug=False):
"""Create the extension file *ext_file* from *source_files*
by invoking the toolchain. Raise :exc:`~codepy.jit.CompileError` in
case of error.

If *debug* is True, print the commands executed.

Implemented by subclasses.
"""

raise NotImplementedError

@abstractmethod
def build_object(self, obj_file, source_files, debug=False):
"""Build a compiled object *obj_file* from *source_files*
by invoking the toolchain. Raise :exc:`CompileError` in
case of error.

If *debug* is True, print the commands executed.

Implemented by subclasses.
"""

raise NotImplementedError

@abstractmethod
def link_extension(self, ext_file, object_files, debug=False):
"""Create the extension file *ext_file* from *object_files*
by invoking the toolchain. Raise :exc:`CompileError` in
case of error.

If *debug* is True, print the commands executed.

Implemented by subclasses.
"""

raise NotImplementedError

@abstractmethod
def with_optimization_level(self, level, **extra):
"""Return a new Toolchain object with the optimization level
set to `level` , on the scale defined by the gcc -O option.
Expand All @@ -133,16 +132,36 @@ def with_optimization_level(self, level, **extra):
simply ignore it.

Level may also be "debug" to specify a debug build.

Implemented by subclasses.
"""

raise NotImplementedError


# {{{ gcc-like tool chain

@dataclass(frozen=True)
class GCCLikeToolchain(Toolchain):
#: Path to the C compiler.
cc: str
#: Path to linker.
ld: str

#: A list of flags to pass to the C compiler.
cflags: List[str]
#: A list of linker flags.
ldflags: List[str]

#: A list of defines to pass to the C compiler.
defines: List[str]
# A list of variables to undefine.
undefines: List[str]

#: Extension for shared library generated by the compiler.
so_ext: str
#: Extension of the object file generated by the compiler.
o_ext: str

def _cmdline(self, source_files, object=False):
raise NotImplementedError

def get_version(self):
result, stdout, stderr = call_capture_output([self.cc, "--version"])
if result != 0:
Expand Down Expand Up @@ -233,6 +252,7 @@ def link_extension(self, ext_file, object_files, debug=False):

# {{{ gcc toolchain

@dataclass(frozen=True)
class GCCToolchain(GCCLikeToolchain):
def get_version_tuple(self):
ver = self.get_version()
Expand Down Expand Up @@ -288,13 +308,14 @@ def remove_prefix(flags, prefix):
if level >= 2 and self.get_version_tuple() >= (4, 3):
oflags.extend(["-march=native", "-mtune=native", ])

return self.copy(cflags=cflags + oflags)
return replace(self, cflags=cflags + oflags)

# }}}


# {{{ nvcc

@dataclass(frozen=True)
class NVCCToolchain(GCCLikeToolchain):
def get_version_tuple(self):
ver = self.get_version()
Expand Down Expand Up @@ -358,6 +379,9 @@ def build_object(self, ext_file, source_files, debug=False):
file=sys.stderr)
raise CompileError("module compilation failed")

def with_optimization_level(self, level, **extra):
raise NotImplementedError

# }}}


Expand Down Expand Up @@ -412,13 +436,15 @@ def _guess_toolchain_kwargs_from_python_config():
"o_ext": object_suffix,
"defines": defines,
"undefines": undefines,
"features": set(),
}


def call_capture_output(*args):
from pytools.prefork import call_capture_output
import sys

from pytools.prefork import call_capture_output

encoding = sys.getdefaultencoding()
result, stdout, stderr = call_capture_output(*args)
return result, stdout.decode(encoding), stderr.decode(encoding)
Expand Down Expand Up @@ -446,7 +472,20 @@ def guess_toolchain():
if sys.maxsize == 0x7fffffff:
kwargs["cflags"].extend(["-arch", "i386"])

return GCCToolchain(**kwargs)
return GCCToolchain(
cc=kwargs["cc"],
ld=kwargs["ld"],
library_dirs=kwargs["library_dirs"],
libraries=kwargs["libraries"],
include_dirs=kwargs["include_dirs"],
cflags=kwargs["cflags"],
ldflags=kwargs["ldflags"],
defines=kwargs["defines"],
undefines=kwargs["undefines"],
so_ext=kwargs["so_ext"],
o_ext=kwargs["o_ext"],
features=set(),
)
else:
raise ToolchainGuessError(
"Unable to determine compiler. Tried running "
Expand All @@ -469,7 +508,6 @@ def guess_nvcc_toolchain():
"undefines": gcc_kwargs["undefines"],
}
kwargs.setdefault("undefines", []).append("__BLOCKS__")
kwargs["cc"] = "nvcc"

return NVCCToolchain(**kwargs)

Expand Down
5 changes: 5 additions & 0 deletions examples/demo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import cgen as c

from codepy.bpl import BoostPythonModule


mod = BoostPythonModule()

mod.add_function(
Expand All @@ -9,6 +12,8 @@
))

from codepy.toolchain import guess_toolchain


cmod = mod.compile(guess_toolchain())

print(cmod.greet())
Loading
Loading