#!/usr/bin/env python

import copy
import os
import sys
import zipfile

# Fixed locations of the install directory and the gcj tools this
# script runs.
PATHS = {
    "libdir": "/usr/lib/gcj",
    "gcj": "/usr/bin/gcj",
    "dbtool": "/usr/bin/gcj-dbtool",
}

# Compiler flags: whatever rpm passed in $RPM_OPT_FLAGS (split on
# whitespace), followed by the flags this script always adds.
GCJFLAGS = os.environ.get("RPM_OPT_FLAGS", "").split()
GCJFLAGS += ["-fjni", "-fPIC", "-findirect-dispatch"]

# Extra flags used only on the commands that produce a shared library.
LDFLAGS = ["-Wl,-Bsymbolic"]

class Error(Exception):
    """A fatal problem whose message is meant to be shown to the user."""

def aot_compile_rpm(basedir, libdir, exclusions = ()):
    """Compile every jarfile found under basedir.

    The shared libraries and class mappings are written to libdir
    re-rooted beneath basedir; that directory is created if it does
    not exist yet.  Jars listed in exclusions are skipped, as are jars
    that weed_jars() finds to be redundant.  Raises Error (via the
    helpers) for unknown exclusions or clashing jar basenames."""
    target = os.path.join(basedir, libdir.strip(os.sep))
    if not os.path.isdir(target):
        os.makedirs(target)

    # Narrow the candidate list one step at a time.
    candidates = find_jars(basedir)
    candidates = strip_exclusions(candidates, exclusions)
    candidates = weed_jars(candidates)
    check_paths(candidates)

    for candidate in candidates:
        aot_compile_jar(candidate, target, libdir)

def find_jars(dir):
    """Return a list of every jarfile under a directory.  Goes on
    magic rather than file extension so we hit wars, ears, rars and
    anything else they cooked up lately.

    Symlinks and non-regular files are ignored, as are archives that
    contain no classfiles or whose classes all live under WEB-INF.
    The result is a list of open JarFile objects sorted by path."""
    # bytes on Python 3, a plain str on Python 2
    zip_magic = "PK".encode("ascii")
    jars = {}
    # os.walk does not descend into symlinked directories by default,
    # which matches the old os.path.walk behaviour.
    for dirpath, dirnames, filenames in os.walk(dir):
        for item in filenames:
            path = os.path.join(dirpath, item)
            if os.path.islink(path) or not os.path.isfile(path):
                continue
            # could use zipfile.is_zipfile() but this is quicker
            probe = open(path, "rb")
            try:
                magic = probe.read(2)
            finally:
                probe.close()
            if magic != zip_magic:
                continue
            jar = JarFile(path, "r")
            # XXX ears contain jars and wars; wars contain jars, and
            # classes in the wrong place: we should recurse and cope
            # with both these cases.
            if (not jar.numClasses()
                or jar.classPrefix().startswith("WEB-INF" + os.sep)):
                # Rejected archives are not used again: release the handle.
                jar.close()
                continue
            jars[path] = jar
    paths = list(jars.keys())
    paths.sort()
    return [jars[path] for path in paths]

class JarFile(zipfile.ZipFile):
    """A ZipFile with a few helpers for looking at its classfiles."""

    def _classNames(self):
        """Return the names of all entries ending in '.class'."""
        return [name for name in self.namelist() if name.endswith(".class")]

    def isSubsetOf(self, other):
        """Return True if every classfile in other is also present in
        this jarfile under the same name and with the same CRC.

        Note the direction: it is other's classes that are looked up
        in self.  Entries that are not classfiles are ignored."""
        for theirs in other.infolist():
            name = theirs.filename
            if name.endswith(".class"):
                try:
                    ours = self.getinfo(name)
                except KeyError:
                    # other has a class we lack
                    return False
                if ours.CRC != theirs.CRC:
                    return False
        return True

    def numClasses(self):
        """Return how many classfiles this jarfile contains."""
        return len(self._classNames())

    def classPrefix(self):
        """Return the longest string prefix shared by the names of all
        classfiles (an empty string when there are none)."""
        return os.path.commonprefix(self._classNames())

def strip_exclusions(jars, exclusions):
    """Remove user-excluded jars from the list.  We're really strict
    about this to ensure that dead options don't get left in
    specfiles.

    jars is a list of objects with a 'filename' attribute; exclusions
    is a sequence of filenames to drop.  A shallow copy of jars with
    the matching entries removed is returned; the input list is left
    alone.  Raises Error if an exclusion matches none of the jars."""
    jars = copy.copy(jars)
    for exclusion in exclusions:
        for jar in jars:
            if jar.filename == exclusion:
                # Safe although we are iterating: we stop straight away.
                jars.remove(jar)
                break
        else:
            raise Error(
                "%s: file does not exist or is not a jar" % exclusion)
    return jars

def weed_jars(jars):
    """Drop jarfiles whose classes are all available in another jar.

    Each ordered pair of distinct jars is examined; when
    first.isSubsetOf(second) is true, second is removed with a warning
    (marked "identical" if the relation also holds the other way
    round) and the scan starts over.  This repeats until a full pass
    removes nothing.  A copy of the list is returned; the argument is
    not modified."""
    survivors = copy.copy(jars)
    rescan = True
    while rescan:
        rescan = False
        for keeper in survivors:
            for candidate in survivors:
                if keeper is candidate:
                    continue
                if not keeper.isSubsetOf(candidate):
                    continue
                note = "subsetted %s" % candidate.filename
                if candidate.isSubsetOf(keeper):
                    note += " (identical)"
                warn(note)
                survivors.remove(candidate)
                # The list changed under us: abandon both loops.
                rescan = True
                break
            if rescan:
                break
    return survivors

def check_paths(jars):
    """Check that each jarfile has a different basename.

    jars is an iterable of objects with a 'filename' attribute.
    Returns nothing; raises Error naming the first basename that
    occurs twice."""
    seen = {}
    for jar in jars:
        name = os.path.basename(jar.filename)
        if name in seen:
            raise Error("%s: duplicate jarname" % name)
        seen[name] = 1

def aot_compile_jar(jar, dir, libdir, max_classes_per_jar = 1000):
    """Generate the shared library and class mapping for one jarfile.
    If the shared library already exists then it will not be
    overwritten.  This is to allow optimizer failures and the like to
    be worked around.

    jar is an open JarFile, dir is the directory the .so and .db files
    are written to, and libdir is the directory recorded in the class
    mapping as the location of the .so.  Jars with more than
    max_classes_per_jar classes are split and compiled piecewise.
    Temporary files (split jars, symlinks, objects) are removed even
    when a compile step fails.  Raises Error if a command fails."""
    soname = os.path.join(dir, os.path.basename(jar.filename) + ".so")
    if os.path.exists(soname):
        warn("not recreating %s" % soname)
    else:
        cleanup = []
        try:
            # prepare
            if jar.numClasses() > max_classes_per_jar:
                warn("splitting %s" % jar.filename)
                sources = split_jarfile(jar, dir, max_classes_per_jar)
                cleanup.extend(sources)
            elif jar.filename.endswith(".jar"):
                sources = [jar.filename]
            else:
                # gcj goes by extension, so give it a '.jar' name
                sources = [symlink_jarfile(jar.filename, dir)]
                cleanup.extend(sources)
            # compile and link
            if len(sources) == 1:
                system([PATHS["gcj"], "-shared"] +
                       GCJFLAGS + LDFLAGS +
                       [sources[0], "-o", soname])
            else:
                objects = []
                for source in sources:
                    objfile = os.path.join(
                        dir, os.path.basename(source) + ".o")
                    system([PATHS["gcj"], "-c"] +
                           GCJFLAGS +
                           [source, "-o", objfile])
                    # Only registered once gcj has succeeded, so every
                    # entry in cleanup names a file that exists.
                    objects.append(objfile)
                    cleanup.append(objfile)
                system([PATHS["gcj"], "-shared"] +
                       GCJFLAGS + LDFLAGS +
                       objects + ["-o", soname])
        finally:
            # clean up, whether or not the build succeeded
            for item in cleanup:
                os.unlink(item)
    # dbtool
    dbname = soname[:soname.rfind(".")] + ".db"
    # The mapping records the library under libdir rather than under
    # the directory it was just built in.
    soname = os.path.join(libdir, os.path.basename(soname))
    system([PATHS["dbtool"], "-n", dbname, "64"])
    system([PATHS["dbtool"], "-f", dbname, jar.filename, soname])

def split_jarfile(src, dir, split):
    """Split large jarfiles to avoid huge assembler files.

    The entries of src (an open ZipFile) are copied, in order, into
    new archives named '<basename of src>.<N>.jar' inside dir, with N
    counting from 1.  A new archive is started at the first classfile
    reached once the current archive already holds at least 'split'
    entries; entries of every kind count towards that total.  Returns
    the list of paths created, which is empty if src has no entries."""
    jarfiles, dst, size = [], None, 0
    try:
        for item in src.infolist():
            if dst is None or (
                    item.filename.endswith(".class") and size >= split):
                if dst is not None:
                    dst.close()
                path = os.path.join(dir, "%s.%d.jar" % (
                    os.path.basename(src.filename), len(jarfiles) + 1))
                jarfiles.append(path)
                dst = zipfile.ZipFile(path, "w", zipfile.ZIP_STORED)
                size = 0
            dst.writestr(item, src.read(item.filename))
            size += 1
    finally:
        # dst is still None when src had no entries at all
        if dst is not None:
            dst.close()
    return jarfiles

def symlink_jarfile(src, dir):
    """Create a symlink to src inside dir, named after src's basename
    with '.jar' appended, so that gcj recognises the file as a jar.
    Returns the path of the new link."""
    linkname = os.path.basename(src) + ".jar"
    linkpath = os.path.join(dir, linkname)
    os.symlink(src, linkpath)
    return linkpath

def system(command):
    """Execute a command.

    command is an argument list whose first element is the path of the
    program to run.  The command line is echoed to stderr first, then
    the program is run and waited for.  Returns nothing on success;
    raises Error if the program exits with a non-zero code or is
    killed by a signal."""
    prefix = os.environ.get("PS4", "+ ")
    # Double the first character of the trace prefix (presumably to
    # look like a nested shell trace).  Slicing rather than indexing
    # keeps an empty $PS4 from raising IndexError.
    prefix = prefix[:1] + prefix
    sys.stderr.write(prefix + " ".join(command) + "\n")

    # With P_WAIT, spawnv returns the exit code, or the negated signal
    # number if the child was killed.
    status = os.spawnv(os.P_WAIT, command[0], command)
    if status > 0:
        raise Error("%s exited with code %d" % (command[0], status))
    elif status < 0:
        raise Error("%s killed by signal %d" % (command[0], -status))

def warn(msg):
    """Print a warning message.

    The message goes to stderr as '<script name>: warning: <msg>',
    where the script name is the basename of sys.argv[0]."""
    sys.stderr.write("%s: warning: %s\n" % (
        os.path.basename(sys.argv[0]), msg))

# Script entry point: validate the rpm build environment, parse the
# --exclude options and compile every jar found in the buildroot.
# Exits with status 1 on a usage error or on any Error.
if __name__ == "__main__":
    try:
        name = os.environ.get("RPM_PACKAGE_NAME")
        if name is None:
            raise Error("this script is designed for use in rpm specfiles")
        arch = os.environ.get("RPM_ARCH")
        if arch == "noarch":
            raise Error("cannot be used on noarch packages")
        buildroot = os.environ.get("RPM_BUILD_ROOT")
        if buildroot in (None, "/"):
            raise Error("bad $RPM_BUILD_ROOT")

        # XXX: This script should not accept options, because having
        # them it cannot be integrated into rpm.  But, gcj cannot
        # build each and every jarfile yet, so we must be able to
        # exclude until it can.
        try:
            options, exclusions = sys.argv[1:], []
            while options:
                if options.pop(0) != "--exclude":
                    raise ValueError
                # Exclusions are given relative to the buildroot.  A
                # missing JAR argument makes pop() raise IndexError.
                exclusions.append(os.path.join(
                    buildroot, options.pop(0).lstrip(os.sep)))
        except (ValueError, IndexError):
            sys.stderr.write("usage: %s [--exclude JAR]...\n" % (
                os.path.basename(sys.argv[0])))
            sys.exit(1)

        aot_compile_rpm(
            buildroot, os.path.join(PATHS["libdir"], name), exclusions)
    except Error:
        # Fetched this way because it is valid on every Python version.
        e = sys.exc_info()[1]
        sys.stderr.write("%s: error: %s\n" % (
            os.path.basename(sys.argv[0]), e))
        sys.exit(1)
