Optimize the interface with Python
[liba.git] / setup.py
blob fd99ceac5e85aa25737ad290c29f78494f3b3399
1 #!/usr/bin/env python
2 try:
3 from setuptools.command.build_ext import build_ext
4 from setuptools import setup, Extension
5 except ImportError:
6 from distutils.command.build_ext import build_ext
7 from distutils.extension import Extension
8 from distutils.core import setup
9 try:
10 USE_CYTHON = True
11 from Cython.Build import cythonize
12 except Exception:
13 USE_CYTHON = False
14 from argparse import ArgumentParser
15 from sys import byteorder
16 from re import findall
17 import os, sys, ctypes
18 import ctypes.util
def strtobool(s):
    """Interpret the string *s* as a boolean flag.

    Returns 1 when *s* (compared case-insensitively) is one of "1", "y",
    "yes" or "true", and 0 for any other string.
    """
    truthy = ("1", "y", "yes", "true")
    return int(s.lower() in truthy)
# Work from the directory containing this script so the relative paths used
# below ("setup.cfg", "src", "include", "python/src/...") resolve correctly.
# Prefer __file__ over sys.argv[0]: when setup.py is exec'd by a build backend
# rather than run directly, sys.argv[0] may name the frontend's helper script
# instead of this file.
os.chdir(os.path.dirname(os.path.abspath(globals().get("__file__") or sys.argv[0])))
# With no command-line arguments, default to a quiet in-place extension build.
if len(sys.argv) < 2:
    sys.argv += ["--quiet", "build_ext", "--inplace"]
# LIBA_OPENMP: optional boolean flag; stays None (OpenMP off) when unset.
LIBA_OPENMP = os.environ.get("LIBA_OPENMP")
if LIBA_OPENMP:
    LIBA_OPENMP = strtobool(LIBA_OPENMP)
# LIBA_FLOAT: presumably the byte size of liba's floating type (4 selects the
# "f" libm variants, 16 the "l" variants in check_math); defaults to 8.
LIBA_FLOAT = os.environ.get("LIBA_FLOAT")
if LIBA_FLOAT:
    LIBA_FLOAT = int(LIBA_FLOAT)
else:
    LIBA_FLOAT = 8
def check_math(text="", size=None):
    """Append ``#define A_HAVE_<FUNC> 1`` lines for available math functions.

    Each function in the probe list is looked up by symbol name in the
    platform's math library via ctypes.

    :param text: prefix that the generated lines are appended to.
    :param size: floating-point size selecting which variant is probed:
        4 probes the ``f``-suffixed names, 16 the ``l``-suffixed names, any
        other value the unsuffixed names. Defaults to ``LIBA_FLOAT``.
    :returns: *text* followed by one ``#define`` line per function found,
        or *text* unchanged if the math library cannot be loaded.

    The macro name always uses the unsuffixed function name, e.g. finding
    ``csqrtf`` still defines ``A_HAVE_CSQRT``.
    """
    if size is None:
        size = LIBA_FLOAT
    if sys.platform == "win32":
        # Probe ucrtbase first, falling back to the runtime find_msvcrt() reports.
        path_libm = (
            ctypes.util.find_library("ucrtbase") or ctypes.util.find_msvcrt()
        )
    else:
        path_libm = ctypes.util.find_library("m")
    try:
        libm = ctypes.CDLL(path_libm)
    except Exception:
        # Deliberate best effort: no loadable library means no A_HAVE_* macros.
        return text
    suffix = {0x04: "f", 0x10: "l"}.get(size, "")
    lines = [text]
    for func in (
        "expm1",
        "log1p",
        "hypot",
        "atan2",
        "csqrt",
        "cpow",
        "cexp",
        "clog",
        "csin",
        "ccos",
        "ctan",
        "csinh",
        "ccosh",
        "ctanh",
        "casin",
        "cacos",
        "catan",
        "casinh",
        "cacosh",
        "catanh",
    ):
        try:
            libm[func + suffix]
        except AttributeError:
            # ctypes raises AttributeError when the symbol is not exported.
            continue
        lines.append("#define A_HAVE_%s 1\n" % func.upper())
    return "".join(lines)
def configure(config):
    """Write the autogenerated C configuration header to the path *config*.

    The version is read from ``setup.cfg`` in the current directory. The
    header defines the version string and its MAJOR/MINOR/PATCH parts, the
    pointer size and byte order of the running interpreter (each guarded so
    a prior definition wins), and the ``A_HAVE_*`` macros from check_math().

    :raises ValueError: if ``setup.cfg`` has no ``version = X.Y.Z`` entry.
    """
    with open("setup.cfg", "rb") as f:
        cfg = f.read().decode("UTF-8")
    # Anchor at the start of a line so keys that merely end in "version"
    # (e.g. "target-version = ...") cannot be picked up first.
    versions = findall(r"(?m)^[ \t]*version[ \t]*=[ \t]*(\S+)", cfg)
    if not versions:
        raise ValueError("setup.cfg: no 'version = ...' entry found")
    version = versions[0]
    # Dots escaped: a bare "." would match any character, not just a dot.
    parts = findall(r"(\d+)\.(\d+)\.(\d+)", version)
    if not parts:
        raise ValueError("setup.cfg: version %r is not MAJOR.MINOR.PATCH" % version)
    major, minor, patch = parts[0]
    order = {"little": 1234, "big": 4321}.get(byteorder)
    vsize = ctypes.sizeof(ctypes.c_void_p)
    text = """/* autogenerated by setup.py */
#define A_VERSION "{}"
#define A_VERSION_MAJOR {}
#define A_VERSION_MINOR {}
#define A_VERSION_PATCH {}
#if !defined A_SIZE_POINTER
#define A_SIZE_POINTER {}
#endif /* A_SIZE_POINTER */
#if !defined A_BYTE_ORDER
#define A_BYTE_ORDER {}
#endif /* A_BYTE_ORDER */
{}""".format(
        version, major, minor, patch, vsize, order, check_math()
    )
    with open(config, "wb") as f:
        f.write(text.encode("UTF-8"))
# Pre-parse the few build options this script needs itself; unknown arguments
# are ignored here and left in sys.argv for setuptools/distutils.
parser = ArgumentParser(add_help=False)
parser.add_argument("-b", "--build-base", default="build")
parser.add_argument("-O", "--link-objects")
args = parser.parse_known_args(sys.argv[1:])
base = args[0].build_base

sources, objects = [], []
# Generated configuration header lives in the build directory.
config_h = os.path.join(base, "a.setup.h")
# Path of the generated header relative to include/a, passed to the C code as
# the string macro A_HAVE_H. Normalise to "/" so Windows backslashes never end
# up inside a C string literal on the compiler command line; "/" is accepted
# as a path separator by MSVC and GCC alike.
a_have_h = os.path.relpath(config_h, "include/a").replace(os.sep, "/")
define_macros = [("A_HAVE_H", '"' + a_have_h + '"'), ("A_EXPORTS", None)]
if LIBA_FLOAT != 8:
    define_macros += [("A_SIZE_FLOAT", LIBA_FLOAT)]
# Use the .pyx source when Cython is available; otherwise fall back to
# python/src/a.c if it exists.
if USE_CYTHON and os.path.exists("python/src/a.pyx"):
    sources += ["python/src/a.pyx"]
elif os.path.exists("python/src/a.c"):
    sources += ["python/src/a.c"]
if not os.path.exists(base):
    os.makedirs(base)
configure(config_h)
# Compile every .c file under src/ into the extension, unless prebuilt objects
# are being linked instead (-O/--link-objects). The check is loop-invariant,
# so it guards the whole walk. Directories and files are visited in sorted
# order so the source list (and hence the build) does not depend on the
# filesystem's directory-listing order.
if not args[0].link_objects:
    for dirpath, dirnames, filenames in os.walk("src"):
        # os.walk (top-down) honours in-place edits to dirnames.
        dirnames.sort()
        sources.extend(
            os.path.join(dirpath, filename)
            for filename in sorted(filenames)
            if os.path.splitext(filename)[-1] == ".c"
        )
# Single C extension module "liba" built from the collected sources, with the
# public headers under include/ and the configuration macros assembled above.
ext_modules = [
    Extension(
        "liba",
        sources,
        include_dirs=["include"],
        define_macros=define_macros,
        language="c",
    )
]
if USE_CYTHON:
    # Let Cython process the extension list (for the .pyx source, if used).
    ext_modules = cythonize(ext_modules, quiet=True)
class Build(build_ext):  # type: ignore
    """build_ext that adds compiler-specific OpenMP and MinGW flags."""

    def build_extensions(self):
        """Adjust every extension's flags for the active compiler, then build."""
        kind = self.compiler.compiler_type
        mingw = kind == "mingw32"
        if mingw:
            self.compiler.define_macro("__USE_MINGW_ANSI_STDIO", 1)
        for ext in self.extensions:
            if kind == "msvc":
                # MSVC spells the OpenMP switch differently and needs no link flag.
                if LIBA_OPENMP:
                    ext.extra_compile_args += ["/openmp"]
                continue
            if LIBA_OPENMP:
                ext.extra_compile_args += ["-fopenmp"]
                ext.extra_link_args += ["-fopenmp"]
            if not mingw:
                continue
            # MinGW: link the GCC runtimes statically, and winpthread as a
            # whole static archive (presumably to avoid shipping its DLL).
            if ext.language == "c++":
                ext.extra_link_args += ["-static-libstdc++"]
            ext.extra_link_args += ["-static-libgcc"]
            ext.extra_link_args += [
                "-Wl,-Bstatic,--whole-archive",
                "-lwinpthread",
                "-Wl,--no-whole-archive",
            ]
        build_ext.build_extensions(self)
179 setup(
180 ext_modules=ext_modules, # type: ignore
181 cmdclass={"build_ext": Build},