#!/usr/local/bin/python3.11

import argparse
import subprocess as sp
import os.path
import os
import shutil
import time
from pprint import pprint

from string import Template as Tm
from timeit import default_timer as timer

# Languages accepted by the --languages option; anything else is ignored with a warning.
SUPPORTED_LANGUAGES = ['C', 'C++', 'Java', 'D', 'Rust', 'Zig', 'Go', 'V', 'Julia', 'OCaml']
# Languages for which a templated source variant is generated in addition to the plain one.
TEMPLATED_SUPPORTED_LANGUAGES = ['C++', 'Java', 'D', 'Rust', 'Zig', 'V', 'Julia']
# Operations accepted by the --operations option.
SUPPORTED_OPERATIONS = ['Check', 'Build']
# Base file name (without extension) of every generated test program.
DEFAULT_PROGRAM_NAME = 'linear'
# Warning flags passed to both the GCC and the Clang drivers.
C_FLAGS = ['-Wall', '-Wextra']

# Column headers of the Markdown result table, in row order.
TABLE_TITLES = ['Lang-uage', 'Oper-ation', 'Temp-lated', "Time [s]", "Slowdown vs [Best]", 'Version', 'Exec']

# Positions of the fields inside one result row as built by row_list().
LANG_IX = 0
OP_IX = 1
TEMPLATED_IX = 2
DUR_IX = 3
SPEEDUP_IX = 4  # filled in later by fill_in_speedups()
EXEC_VERSION_IX = 5
EXEC_PATH_IX = 6


def srcIdOf(lang, templated):
    """Return the key identifying the generated source for `lang`.

    Languages listed in TEMPLATED_SUPPORTED_LANGUAGES get a '-Templated' or
    '-Untemplated' suffix; every other language is identified by its name
    alone, regardless of `templated`.
    """
    if lang not in TEMPLATED_SUPPORTED_LANGUAGES:
        return lang
    suffix = '-Templated' if templated else '-Untemplated'
    return lang + suffix


def opIdOf(lang, templated, op, exec_path=None):
    """Return the key identifying one benchmarked operation.

    The key is '<source id>-<op>-<exec_path>'. For D an omitted `exec_path`
    defaults to 'dmd'; for other languages an omitted path is rendered as
    the string 'None'.
    """
    if exec_path is None and lang == 'D':
        exec_path = 'dmd'
    return '-'.join((srcIdOf(lang, templated), op, str(exec_path)))


def fill_in_speedups(results, table_titles, op_ix):
    """Fill the slowdown column of every row belonging to one operation.

    `op_ix` is compared against each row's OP_IX field; despite its name it
    is the operation value stored there (e.g. 'Check'), not an index.
    Each matching row gets '<factor> [<language of fastest row>]' stored at
    SPEEDUP_IX, where factor is its duration divided by the smallest
    duration among the matching rows. Rows are modified in place.
    `table_titles` is accepted but not used.
    """
    matching = [row for row in results if row[OP_IX] == op_ix]
    if not matching:
        return

    # min() returns the first of several equally fast rows.
    fastest = min(matching, key=lambda row: row[DUR_IX])
    best_dur = fastest[DUR_IX]
    best_label = ' [' + fastest[LANG_IX] + ']'

    for row in matching:
        row[SPEEDUP_IX] = factor_str(row[DUR_IX] / best_dur) + best_label


def stringify_speedup(results):
    """Replace each row's numeric duration with its formatted string.

    Must run after fill_in_speedups(), which still needs the durations as
    numbers. Rows are modified in place.
    """
    for result in results:
        # Format the duration itself. The slowdown column already holds a
        # string, so feeding it to ms_str() would raise.
        result[DUR_IX] = ms_str(result[DUR_IX])


def row_list(durs, lang, op, exec_path, exec_version, dur, templated):
    """Build one result-table row, ordered as the *_IX constants expect.

    A leading home directory in `exec_path` is abbreviated to '~'. The
    slowdown column is left as None. `durs` is accepted but not used.

    Returns:
        [lang, op, 'Yes'/'No', dur, None, exec_version, '`exec_path`']
    """
    home = os.getenv('HOME')
    if home:
        home = home.rstrip(os.sep)
    if home:  # HOME may be unset, empty, or just the root directory
        if exec_path == home:
            exec_path = '~'
        elif exec_path.startswith(home + os.sep):
            # Cut the prefix by length: str.lstrip() would strip a *set of
            # characters* and eat into the rest of the path. Matching on
            # `home + os.sep` keeps '/home/person' from matching '/home/per'.
            remainder = exec_path[len(home):].lstrip(os.sep)
            exec_path = os.path.join('~', remainder)

    return [lang,
            op,
            'Yes' if templated else 'No',
            dur,
            None,
            exec_version,
            '`' + exec_path + '`',
            ]


def ms_str(dur):
    """Format a duration with three decimals."""
    return f"{dur:.3f}"


def factor_str(factor):
    """Format a ratio with one decimal."""
    return f"{factor:.1f}"


def md_header(text, nr):
    """Return a Markdown header line: `nr` '#' characters, a space, then `text`."""
    hashes = '#' * nr
    return ' '.join((hashes, text))


def md_table(titles, rows):
    """Render `rows` as a Markdown table headed by `titles`.

    The rows are first completed in place: slowdown factors are filled in
    for the 'Check' and 'Build' operations, then durations are converted to
    strings. Every line has the form '| a | b | ... | ' plus a newline.
    """
    for operation in ('Check', 'Build'):
        fill_in_speedups(rows, titles, operation)
    stringify_speedup(rows)

    def md_row(cells):
        return '| ' + ''.join(str(cell) + ' | ' for cell in cells) + '\n'

    # The third column uses the default alignment marker, all others are centered.
    alignment = ['---' if ix == 2 else ':---:' for ix in range(len(titles))]

    return md_row(titles) + md_row(alignment) + ''.join(md_row(row) for row in rows)


def main():
    """Parse the command line, generate the sources, benchmark, print a table.

    Unsupported languages and operations are dropped with a warning. Source
    files are generated for every selected language, then each selected
    compiler/operation combination is run and its rows are collected.
    The rows are finally printed as one Markdown table.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--languages", "--langs", type=str,
                        default=','.join(SUPPORTED_LANGUAGES),
                        help="Languages to benchmark separated by comma")

    parser.add_argument("--operations", "--ops", type=str,
                        default=','.join(SUPPORTED_OPERATIONS),
                        help="Operations to perform separated by comma")

    parser.add_argument("--function-count", "--fc", type=int,
                        default=100,
                        help="Number of functions generated and called")

    parser.add_argument("--function-depth", "--fd", type=int,
                        default=100,
                        help="Function call depth")

    parser.add_argument("--run-count", "--rc", type=int,
                        default=1,
                        help="Number of runs for each compilation")

    args = parser.parse_args()

    # Normalize the spelling of each language name; 'OCaml' needs a special
    # case because str.capitalize() would produce 'Ocaml'.
    args.languages = list(map(lambda x: 'OCaml' if x.lower() == 'ocaml' else x.capitalize(),
                              args.languages.split(',')))  # into a list
    filtered_languages = []
    for language in args.languages:
        if language in SUPPORTED_LANGUAGES:
            filtered_languages.append(language)
        else:
            print('Warning: Ignoring unsupported language ' + language)
    args.languages = filtered_languages

    args.operations = list(map(lambda x: x.capitalize(), args.operations.split(',')))  # into a list of capitalized names
    filtered_operations = []
    for operation in args.operations:
        if operation in SUPPORTED_OPERATIONS:
            filtered_operations.append(operation)
        else:
            print('Warning: Ignoring unsupported operation ' + operation)
    args.operations = filtered_operations

    # Paths of the generated source files, keyed by srcIdOf().
    gpaths = generate_code(args=args)

    execs = {}                  # executable paths, keyed by language or operation id
    durs = {}                  # minimum durations, keyed by operation id (see opIdOf)

    results = []                # table rows, one per benchmarked compiler run

    # D runs first: the other benchmarks pass its operation id to print_speedup().
    if 'D' in args.languages:
        if 'Check' in args.operations:
            results += benchmark_D(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=False, use_dips=True)
            results += benchmark_D(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=True, use_dips=True)
        if 'Build' in args.operations:
            results += benchmark_D(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False, use_dips=True)
            results += benchmark_D(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=True, use_dips=True)

    if 'C' in args.languages:
        if 'Check' in args.operations:
            results += benchmark_GCC(lang='C', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=False)
            results += benchmark_Clang(lang='C', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=False)
        if 'Build' in args.operations:
            results += benchmark_GCC(lang='C', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)
            results += benchmark_Clang(lang='C', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)

    if 'C++' in args.languages:
        if 'Check' in args.operations:
            results += benchmark_GCC(lang='C++', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=False)
            results += benchmark_Clang(lang='C++', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=False)
            results += benchmark_GCC(lang='C++', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=True)
            results += benchmark_Clang(lang='C++', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=True)
        if 'Build' in args.operations:
            results += benchmark_GCC(lang='C++', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)
            results += benchmark_Clang(lang='C++', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)
            results += benchmark_GCC(lang='C++', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=True)
            results += benchmark_Clang(lang='C++', execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=True)

    if 'Go' in args.languages:
        if 'Check' in args.operations:
            results += benchmark_Go(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=False)
        if 'Build' in args.operations:
            results += benchmark_Go(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)

    if 'V' in args.languages:
        if 'Build' in args.operations:
            results += benchmark_V(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)
        # TODO activate when issue https://github.com/vlang/v/issues/5818 has been fixed:
        # results += benchmark_V(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=True)

    # Zig is only benchmarked for 'Check'.
    if 'Zig' in args.languages:
        if 'Check' in args.operations:
            results += benchmark_Zig(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=False)
            results += benchmark_Zig(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=True)

    if 'Rust' in args.languages:
        if 'Check' in args.operations:
            results += benchmark_Rust(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=False)
            results += benchmark_Rust(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Check', templated=True)
        if 'Build' in args.operations:
            results += benchmark_Rust(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)
            results += benchmark_Rust(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=True)

    # Java, Julia and OCaml are only benchmarked for 'Build'.
    if 'Java' in args.languages:
        if 'Build' in args.operations:
            results += benchmark_Java(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)

    if 'Julia' in args.languages:
        if 'Build' in args.operations:
            if args.function_count * args.function_depth <= 5000:  # only for small workloads, takes too long otherwise
                results += benchmark_Julia(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False)
                results += benchmark_Julia(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=True)

    if 'OCaml' in args.languages:
        if 'Build' in args.operations:
            # The size limit applies to the native-code compiler only; bytecode always runs.
            if args.function_count * args.function_depth <= 10000:  # only for small workloads, takes too long otherwise
                results += benchmark_OCaml(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False, bytecode=False)
            results += benchmark_OCaml(execs=execs, durs=durs, gpaths=gpaths, args=args, op='Build', templated=False, bytecode=True)

    print(md_table(TABLE_TITLES, results))


def generate_code(args):
    """Generate the test programs for every language in `args.languages`.

    Each language gets an untemplated program; languages listed in
    TEMPLATED_SUPPORTED_LANGUAGES additionally get a templated one.

    Returns:
        dict mapping srcIdOf(lang, templated) to the generated file's path.
    """
    print(md_header('Code-generation:', 1))
    gpaths = {}
    for lang in args.languages:
        if lang in TEMPLATED_SUPPORTED_LANGUAGES:
            variants = (False, True)
        else:
            variants = (False,)
        for templated in variants:
            gpaths[srcIdOf(lang, templated)] = generate_test_program(lang=lang,
                                                                     args=args,
                                                                     templated=templated)
    print()
    return gpaths


def benchmark_GCC(lang, execs, durs, gpaths, args, op, templated):
    """Benchmark every versioned GCC driver found on PATH.

    Looks for 'gcc-N' (lang 'C') or 'g++-N' (lang 'C++') for N in 5..14 and
    compiles the generated source with each one found: '-c' for op 'Build',
    '-fsyntax-only' otherwise. The minimum duration is stored in `durs`
    under the operation id. `execs` is accepted but not used.

    Returns:
        list of result rows, one per compiler found.

    Raises:
        ValueError: if `lang` is neither 'C' nor 'C++'.
    """
    driver_names = {'C': 'gcc', 'C++': 'g++'}
    if lang not in driver_names:
        # An explicit raise instead of assert(False): asserts are stripped
        # under `python -O`, which would leave `exe` unbound below.
        raise ValueError('benchmark_GCC supports only C and C++, got ' + repr(lang))

    results = list()
    print(md_header('GCC:', 1))
    GCC_VERSIONS = range(5, 15)
    exe_args = ['-c'] if op == 'Build' else ['-fsyntax-only']
    for gcc_version in GCC_VERSIONS:
        exe = shutil.which(driver_names[lang] + '-' + str(gcc_version))
        if exe is None:
            continue
        # Third whitespace-separated token of `--version`, cut at the first '-'.
        version = sp.run([exe, '--version'], stdout=sp.PIPE).stdout.decode('utf-8').split()[2].split('-')[0]
        dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                               args=[exe] + C_FLAGS + exe_args,
                               run_count=args.run_count,
                               op=op,
                               compiler_version=version)
        opId = opIdOf(lang, templated, op, exe)
        durs[opId] = dur_min
        results.append(row_list(durs, lang, op, exe, version, dur_min, templated))
        print_speedup(durs, from_opId=opIdOf('D', templated, op), to_opId=opId)
    print()
    return results


def benchmark_Clang(lang, execs, durs, gpaths, args, op, templated):
    """Benchmark every versioned Clang driver found on PATH.

    Looks for 'clang-N' (lang 'C') or 'clang++-N' (lang 'C++') for N in
    5..14 and compiles the generated source with each one found: '-c' for
    op 'Build', '-fsyntax-only' otherwise. The minimum duration is stored
    in `durs` under the operation id. `execs` is accepted but not used.

    Returns:
        list of result rows, one per compiler found.

    Raises:
        ValueError: if `lang` is neither 'C' nor 'C++'.
    """
    driver_names = {'C': 'clang', 'C++': 'clang++'}
    if lang not in driver_names:
        # An explicit raise instead of assert(False): asserts are stripped
        # under `python -O`, which would leave `exe` unbound below.
        raise ValueError('benchmark_Clang supports only C and C++, got ' + repr(lang))

    results = list()
    C_CLANG_FLAGS = C_FLAGS + ['-fno-color-diagnostics', '-fno-caret-diagnostics', '-fno-diagnostics-show-option']
    print(md_header('Clang:', 1))
    CLANG_VERSIONS = range(5, 15)
    exe_args = ['-c'] if op == 'Build' else ['-fsyntax-only']
    for clang_version in CLANG_VERSIONS:
        exe = shutil.which(driver_names[lang] + '-' + str(clang_version))
        if exe is None:
            continue
        # Third whitespace-separated token of `--version`, cut at the first '-'.
        version = sp.run([exe, '--version'], stdout=sp.PIPE).stdout.decode('utf-8').split()[2].split('-')[0]
        dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                               args=[exe] + C_CLANG_FLAGS + exe_args,
                               run_count=args.run_count,
                               op=op,
                               compiler_version=version)
        results.append(row_list(durs, lang, op, exe, version, dur_min, templated))
        opId = opIdOf(lang, templated, op, exe)
        print(md_header(opId + ':', 2))
        durs[opId] = dur_min
        print_speedup(durs, from_opId=opIdOf('D', templated, op), to_opId=opId)
    print()
    return results


def _benchmark_D_compiler(name, check_flag, extra_flags, version_of,
                          execs, durs, gpaths, args, op, templated):
    """Benchmark one D compiler; return its result row in a list, or [] if absent.

    `name` is looked up on PATH and also used (instead of the resolved
    path) in the operation id. `check_flag` is the flag used for any op
    other than 'Build'. `version_of` picks the version out of the
    whitespace-split `--version` output.
    """
    lang = 'D'
    exe = shutil.which(name)
    if exe is None:
        return []

    opId = opIdOf(lang, templated, op, name)
    print(md_header(opId + ':', 2))
    tokens = sp.run([exe, '--version'], stdout=sp.PIPE).stdout.decode('utf-8').split()
    version = version_of(tokens)
    execs.setdefault(opId, exe)  # keep the first executable recorded for this id

    mode_args = ['-c'] if op == 'Build' else [check_flag]
    dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                           args=[exe] + mode_args + extra_flags,
                           run_count=args.run_count,
                           op=op,
                           compiler_version=version)
    row = row_list(durs, lang, op, exe, version, dur_min, templated)
    durs[opId] = dur_min
    return [row]


def benchmark_D(execs, durs, gpaths, args, op, templated, use_dips):
    """Benchmark the D compilers dmd and ldmd2 (gdc is currently disabled).

    With `use_dips` the flags '-dip25', '-dip1008' and '-dip1000' are added
    for dmd and ldmd2. Durations are stored in `durs` and executables in
    `execs`, both keyed by operation id.

    Returns:
        list of result rows, one per compiler found.
    """
    dip_flags = ['-dip25', '-dip1008', '-dip1000'] if use_dips else []
    shared = dict(execs=execs, durs=durs, gpaths=gpaths, args=args,
                  op=op, templated=templated)

    rows = []
    rows += _benchmark_D_compiler('dmd', '-o-', dip_flags,
                                  lambda tokens: tokens[3], **shared)
    rows += _benchmark_D_compiler('ldmd2', '-o-', dip_flags,
                                  lambda tokens: tokens[6][1:-2], **shared)

    gdc_enabled = False         # disabled for now until it supports larger input
    if gdc_enabled:
        rows += _benchmark_D_compiler('gdc', '-fsyntax-only', [],
                                      lambda tokens: tokens[3], **shared)

    print()

    return rows


def benchmark_Go(execs, durs, gpaths, args, op, templated):
    """Benchmark Go compilation using gccgo, if it is on PATH.

    Uses '-c' for op 'Build' and '-fsyntax-only -S' otherwise. Records the
    executable in `execs` under 'Go' and the minimum duration in `durs`
    under the operation id.

    Returns:
        list with one result row, or an empty list when gccgo is missing.
    """
    lang = 'Go'
    exe = shutil.which('gccgo')
    if exe is None:
        return []

    if op == 'Build':
        mode_args = ['-c']
    else:
        mode_args = ['-fsyntax-only', '-S']

    execs[lang] = exe
    version_output = sp.run([exe, '--version'], stdout=sp.PIPE).stdout.decode('utf-8')
    version = version_output.split()[3]
    dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                           args=[exe] + mode_args,
                           run_count=args.run_count,
                           op=op,
                           compiler_version=version)
    opId = opIdOf(lang, templated, op, exe)
    print(md_header(opId + ':', 2))
    durs[opId] = dur_min
    row = row_list(durs, lang, op, exe, version, dur_min, templated)
    print_speedup(durs, from_opId=opIdOf('D', templated, op), to_opId=opId)
    print()
    return [row]


def benchmark_OCaml(execs, durs, gpaths, args, op, templated, bytecode):
    """Benchmark OCaml compilation with ocamlc (`bytecode`) or ocamlopt.

    The source is always compiled with '-c'. Records the executable in
    `execs` under 'OCaml' and the minimum duration in `durs` under the
    operation id. The speedup is always printed against D's 'Build'
    operation, whatever `op` is.

    Returns:
        list with one result row, or an empty list when the compiler is missing.
    """
    lang = 'OCaml'
    compiler_name = 'ocamlc' if bytecode else 'ocamlopt'
    exe = shutil.which(compiler_name)
    if exe is None:
        return []

    execs[lang] = exe
    version_output = sp.run([exe, '--version'], stdout=sp.PIPE).stdout.decode('utf-8')
    version = version_output.split()[0]
    dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                           args=[exe, '-c'],
                           run_count=args.run_count,
                           op=op,
                           compiler_version=version)
    opId = opIdOf(lang, templated, op, exe)
    print(md_header(opId + ':', 2))
    durs[opId] = dur_min
    row = row_list(durs, lang, op, exe, version, dur_min, templated)
    print_speedup(durs, from_opId=opIdOf('D', templated, 'Build'), to_opId=opId)
    print()
    return [row]


def benchmark_V(execs, durs, gpaths, args, op, templated):
    """Benchmark V (vlang.io) compilation, if a `v` executable can be found.

    The compiler is run with '-cc clang -backend x64'. Records the
    executable in `execs` under 'V' and the minimum duration in `durs`
    under the operation id.

    Returns:
        list with one result row, or an empty list when no compiler is found.
    """
    results = list()
    lang = 'V'                                   # vlang.io
    # Keep preferring the original local checkout, but fall back to PATH so
    # the benchmark is not tied to one user's home directory.
    exe = shutil.which('/home/per/ware/vlang/v') or shutil.which('v')
    if exe is not None:
        execs[lang] = exe
        version = sp.run([exe, '--version'], stdout=sp.PIPE).stdout.decode('utf-8').split()[1]
        vlang_backends = ['c', 'js', 'x64', 'v2', 'experimental']
        dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                               args=[exe, '-cc', 'clang', '-backend', vlang_backends[2]],
                               run_count=args.run_count,
                               op=op,
                               compiler_version=version)
        opId = opIdOf(lang, templated, op, exe)
        print(md_header(opId + ':', 2))
        durs[opId] = dur_min
        results.append(row_list(durs, lang, op, exe, version, dur_min, templated))
        print_speedup(durs, from_opId=opIdOf('D', templated, op), to_opId=opId)
        print()
    return results


def benchmark_Zig(execs, durs, gpaths, args, op, templated):
    """Benchmark Zig compilation, if `zig` is on PATH.

    Records the executable in `execs` under 'Zig' and the minimum duration
    in `durs` under the operation id.

    Returns:
        list with one result row, or an empty list when zig is missing.
    """
    lang = 'Zig'
    exe = shutil.which('zig')
    if exe is None:
        return []

    execs[lang] = exe
    version_output = sp.run([exe, 'version'], stdout=sp.PIPE).stdout.decode('utf-8')
    version = version_output.split()[0]
    # No syntax flag currently, so compile to an object file instead.
    command = [exe, 'build-obj', '-fno-emit-bin']
    dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                           args=command,
                           run_count=args.run_count,
                           op=op,
                           compiler_version=version)
    opId = opIdOf(lang, templated, op, exe)
    print(md_header(opId + ':', 2))
    durs[opId] = dur_min
    row = row_list(durs, lang, op, exe, version, dur_min, templated)
    print_speedup(durs, from_opId=opIdOf('D', templated, op), to_opId=opId)
    print()
    return [row]


def set_rustup_channel(channel):
    """Run `rustup default <channel>` and wait for it to finish.

    The command's output is captured and discarded; its exit status is not
    checked.
    """
    sp.run(['rustup', 'default', channel],
           stdout=sp.PIPE,
           stderr=sp.PIPE)


def benchmark_Rust(execs, durs, gpaths, args, op, templated):
    """Benchmark rustc, once per rustup channel when rustup is available.

    With rustup on PATH the default toolchain is switched to 'stable' and
    then 'nightly', measuring rustc after each switch; without it the
    rustc found on PATH is measured once. The default toolchain is not
    switched back afterwards. Durations are stored in `durs` and
    executables in `execs`, both keyed by operation id.

    Returns:
        list of result rows, one per measured toolchain.
    """
    lang = 'Rust'
    rows = []

    rustup_exe = shutil.which('rustup')
    channels = ['stable', 'nightly'] if rustup_exe else [None]

    for channel in channels:
        if rustup_exe is not None:
            set_rustup_channel(channel)

        exe = shutil.which('rustc')
        if exe is None:
            continue

        opId = opIdOf(lang, templated, op, exe)
        print(md_header(opId + ':', 2))
        execs.setdefault(opId, exe)

        version_output = sp.run([exe, '--version'], stdout=sp.PIPE).stdout.decode('utf-8')
        version = version_output.split()[1]

        # Ways to check without generating code, see:
        # https://stackoverflow.com/questions/53250631/does-rust-have-a-way-to-perform-syntax-and-semantic-analysis-without-generating/53250674#53250674
        # https://stackoverflow.com/questions/51485765/run-rustc-to-check-a-program-without-generating-any-files
        # https://github.com/rust-lang/rfcs/issues/1476
        # Alternatives:
        # - `rustc --emit=metadata -Z no-codegen`
        # - 'rustc', '--crate-type', 'lib', '--emit=mir', '-o', '/dev/null', '--test'
        if op == 'Check':
            if channel == 'nightly':
                check_args = ['-Z', 'no-codegen']  # TODO why is this not available on stable yet?
            else:
                check_args = ['--emit=mir', '-o', '/dev/null']  # Used by Flycheck. Twice as slow as `-Z no-codegen`

        if op == 'Build':
            command = [exe]
        else:
            command = [exe] + check_args

        dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                               args=command,
                               run_count=args.run_count,
                               op=op,
                               compiler_version=version)
        durs[opId] = dur_min
        rows.append(row_list(durs, lang, op, exe, version, dur_min, templated))
        print_speedup(durs,
                      from_opId=opIdOf('D', templated, op),
                      to_opId=opId)
        print()
    return rows


def benchmark_Java(execs, durs, gpaths, args, op, templated):
    """Benchmark Java compilation with javac, if it is on PATH.

    The version is read from the stderr output of `javac -version`. The
    executable is stored in `execs` and the minimum duration in `durs`,
    both keyed by operation id.

    Returns:
        list with one result row, or an empty list when javac is missing.
    """
    lang = 'Java'
    exe = shutil.which('javac')
    if exe is None:
        return []

    opId = opIdOf(lang, templated, op, exe)
    print(md_header(opId + ':', 2))
    execs[opId] = exe
    version_output = sp.run([exe, '-version'], stderr=sp.PIPE).stderr.decode('utf-8')
    version = version_output.split()[1]
    dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                           args=[exe, '-Xdiags:verbose'],
                           run_count=args.run_count,
                           op=op)
    durs[opId] = dur_min
    row = row_list(durs, lang, op, exe, version, dur_min, templated)
    print_speedup(durs,
                  from_opId=opIdOf('D', templated, op),
                  to_opId=opId)
    print()
    return [row]


def benchmark_Julia(execs, durs, gpaths, args, op, templated):
    """Benchmark Julia by running the generated script with `julia`, if on PATH.

    The executable is stored in `execs` and the minimum duration in `durs`,
    both keyed by operation id.

    Returns:
        list with one result row, or an empty list when julia is missing.
    """
    lang = 'Julia'
    exe = shutil.which('julia')
    if exe is None:
        return []

    opId = opIdOf(lang, templated, op, exe)
    print(md_header(opId + ':', 2))
    execs[opId] = exe
    version_output = sp.run([exe, '--version'], stdout=sp.PIPE).stdout.decode('utf-8')
    version = version_output.split()[2]
    dur_min = compile_file(path=gpaths[srcIdOf(lang, templated)],
                           args=[exe],
                           run_count=args.run_count,
                           op=op)
    durs[opId] = dur_min
    row = row_list(durs, lang, op, exe, version, dur_min, templated)
    print_speedup(durs,
                  from_opId=opIdOf('D', templated, op),
                  to_opId=opId)
    print()
    return [row]


def compile_file(path, args, run_count=1,
                 op='Check',
                 compiler_version=None):
    """Run a compiler command on `path` and return the fastest wall-clock time.

    Args:
        path: source file, appended as the last command-line argument.
        args: command line; args[0] is the compiler executable.
        run_count: number of times the command is run.
        op: operation name, used only in the printed summary.
        compiler_version: optional version shown in the printed summary.

    Returns:
        The minimum duration in seconds over all runs, or None when the
        compiler cannot be found.

    Raises:
        ValueError: if `run_count` is less than 1.
    """
    compiler = shutil.which(args[0])
    if compiler is None:
        print('Could not find compiler:', args[0])
        return None

    if run_count < 1:
        # Fail with a clear message instead of min() complaining about an
        # empty sequence further down.
        raise ValueError('run_count must be at least 1, got {}'.format(run_count))

    durs = []
    for _ in range(run_count):
        start = timer()
        with sp.Popen(args + [path],
                      stdout=sp.PIPE,
                      stderr=sp.PIPE) as proc:
            out, err = proc.communicate()
            end = timer()
        durs.append(end - start)
        if out:
            print('stdout:', out)
        if err:
            print('stderr:', err)
        if proc.returncode != 0:
            # A failed compile still yields a timing; flag it so the number
            # is not mistaken for a successful run.
            print('Warning: "{}" exited with status {}'.format(args[0], proc.returncode))

    dur_min = min(durs)

    compiler_version_str = " version " + str(compiler_version) if compiler_version is not None else ""
    print('- {} took {:1.3f} seconds (using "{}"{})'.format(op, dur_min, args[0],
                                                            compiler_version_str))

    return dur_min


def long_types_of_lang(lang):
    """Return the type names used in generated code for lower-cased `lang`.

    Returns a one-element list, or None for an unknown language.
    """
    type_name_by_lang = {
        'c': 'long',
        'c++': 'long',
        'java': 'long',
        'd': 'long',
        'rust': 'i64',
        'zig': 'i64',
        'v': 'i64',
        'go': 'int64',
        'julia': 'Int64',
        'ocaml': 'float',
    }
    type_name = type_name_by_lang.get(lang)
    if type_name is None:
        return None
    return [type_name]


def language_file_extension(lang):
    """Return the source file extension for lower-cased `lang`.

    Languages without a special case use their own name as the extension.
    """
    special_extensions = {'rust': 'rs', 'julia': 'jl', 'ocaml': 'ml'}
    return special_extensions.get(lang, lang)


def generate_test_program(lang,
                          args,
                          templated,
                          root_path='generated'):
    """Write the generated test program for `lang` and return its path.

    The file is written to '<root_path>/<lang lower-cased>/' and named
    after DEFAULT_PROGRAM_NAME, with a '_t' suffix for the templated
    variant. The time taken to write it is printed.
    """
    llang = lang.lower()
    types = long_types_of_lang(llang)

    dir_path = os.path.join(root_path, llang)
    os.makedirs(dir_path, exist_ok=True)
    variant_suffix = '_t' if templated else ''
    file_name = DEFAULT_PROGRAM_NAME + variant_suffix + '.' + language_file_extension(llang)
    path = os.path.join(dir_path, file_name)

    started = timer()
    with open(path, 'w') as f:
        generate_test_language_specific_prefix(llang, DEFAULT_PROGRAM_NAME, f, templated=templated)

        generate_linear_test_function_definition_set(llang, types, args, f,
                                                     templated=templated)
        generate_test_main_header(llang, types, f, templated)
        # The main body: one accumulator per type, then one call per function.
        for typ in types:
            generate_linear_test_function_variable(llang, typ, f, templated=templated)
            for findex in range(args.function_count):
                generate_test_function_call(llang, findex, typ, f, templated=templated)

        generate_test_language_specific_postfix(llang, types, f)
    elapsed = timer() - started
    print("- Generating {} took {:1.3f} seconds ({})".format(path, elapsed, lang))

    return path


def generate_test_function_call(lang, findex, typ, f, templated):
    """Write the statement in main that adds function `findex`'s result to the sum.

    `lang` is the lower-cased language name, `typ` the element type name
    and `f` a writable text stream. One line ending in a newline is written.
    """
    number = str(findex)

    if lang == "zig" and templated:
        # Zig needs the template type passed explicitly.
        statement = '    {T}_sum += add_{T}_n{N}({T}, {N})'.format(T=typ, N=number)
    elif lang == "v" and templated:
        # V needs an explicit template type for now.
        # See: https://github.com/vlang/v/issues/5818
        statement = '    {T}_sum += add_{T}_n{N}<{T}>({N})'.format(T=typ, N=number)
    elif lang == "ocaml":
        statement = '    let {T}_sum = {T}_sum +. (add_{T}_n{N} {N}.0) in'.format(T=typ, N=number)
    else:
        statement = '    {T}_sum += add_{T}_n{N}({N})'.format(T=typ, N=number)

    # Java statements sit one level deeper, inside the wrapping class.
    extra_indent = '    ' if lang in ["java"] else ''
    terminator = '' if lang in ["v", "ocaml"] else ';'
    f.write(extra_indent + statement + terminator + '\n')


def generate_test_language_specific_prefix(lang, program_name, f, templated):
    """Write whatever `lang` requires at the very top of the generated file.

    Go gets a package declaration, Rust its `use` lines (one more when
    `templated`), Java the opening of the wrapping class. Nothing is
    written for other languages.
    """
    header = []
    if lang == "go":
        header.append('package ' + program_name + ';\n\n')
    elif lang == "rust":
        header.append('use std::process::exit;\n')
        if templated:
            header.append('use std::ops::Add;\n')
    elif lang == "java":
        header.append('class HelloWorld {\n')
    f.writelines(header)


def generate_test_language_specific_postfix(lang, types, f):
    """Write the end of the generated main function for `lang`.

    The text refers to the sum variable of the first entry in `types`.
    Java additionally gets the brace that closes the wrapping class.
    """
    endings = {
        "rust": '    exit((${T}_sum % 4294967296) as i32);\n}\n',
        "v": '    exit(int(${T}_sum))\n}\n',
        "java": '        System.exit(${T}_sum == 42 ? 1 : 0);\n    }\n',
        "zig": '\n}\n',
        "julia": '    return ${T}_sum;\nend\n\nmain()\n',
        "d": '    return cast(int)${T}_sum;\n}\n',
        "ocaml": '    exit (if ${T}_sum = 42.0 then 1 else 0)\n',
    }
    # Languages without an entry simply return the sum.
    ending = endings.get(lang, '    return ${T}_sum;\n}\n')
    f.write(Tm(ending).substitute(T=types[0]))

    if lang == "java":
        f.write('}\n')            # one extra closing brace for class


def generate_linear_test_function_definition_set(lang, types, args, f,
                                                 templated):
    """Write every function definition of the generated program to `f`.

    For each type, `args.function_count` groups of `args.function_depth`
    definitions are written. A blank line follows each group and another
    one follows each type.
    """
    function_count = args.function_count
    depth = args.function_depth
    for typ in types:
        for findex in range(function_count):
            for fheight in range(depth):
                generate_test_function_definition(args, lang, typ, findex, fheight, f,
                                                  templated=templated)
            f.write('\n')
        f.write('\n')


def function_name(typ, findex, fheight):
    """Return the generated function's name for a type, index and height.

    With `fheight` None the '_h<height>' suffix is left out.
    """
    base = 'add_{}_n{}'.format(typ, findex)
    if fheight is None:
        return base
    return '{}_h{}'.format(base, fheight)


def generate_test_function_definition(args, lang, typ, findex, fheight, f,
                                      templated):
    """Write one function of a linear call chain to ``f``.

    The function at height 0 returns an expression built from its argument
    and ``findex``; every higher function additionally calls the function one
    height below it.  The function at height ``args.function_depth - 1`` is
    given the un-suffixed name (``function_name`` with ``fheight=None``).
    Languages without a pattern below produce no output.

    :param args: object providing ``function_depth``.
    :param lang: lower-case language identifier.
    :param typ: type name used for the parameter and the return value.
    :param findex: index of the chain the function belongs to.
    :param fheight: height of the function inside its chain (0 = bottom).
    :param f: writable text stream.
    :param templated: when truthy, emit the generic variant where the
        language has one.
    """
    offset = str(findex)
    rust_generic = lang == 'rust' and templated

    # --- expression returned by the function ---
    if fheight == 0:
        if rust_generic:
            body = 'x'          # because Rust is picky
        elif lang == 'ocaml':
            body = 'x +. ' + offset + '.0'
        else:
            body = 'x + ' + offset
    else:
        # Zig needs explicit template parameter
        type_arg = (typ + ', ') if (lang == 'zig' and templated) else ''
        callee = function_name(typ, findex, fheight - 1)
        if rust_generic:
            body = 'x + ' + callee + '(x)'
        elif lang == 'v' and templated:
            body = 'x + ' + callee + '<' + typ + '>(x)'
        elif lang == 'ocaml':
            body = 'x +. (' + callee + ' x) +. ' + offset + '.0'
        else:
            body = 'x + ' + callee + '(' + type_arg + 'x) + ' + offset

    # --- name: the top of the chain carries no height suffix ---
    is_chain_top = fheight == args.function_depth - 1
    name = function_name(typ, findex, None if is_chain_top else fheight)

    # --- pick the definition pattern and its placeholder values ---
    fields = {'T': typ, 'F': str(name), 'X': body}
    if lang == "c":
        pattern = '${T} ${F}(${T} x) { return ${X}; }\n'
    elif lang == "java":
        pattern = '    static ${T} ${F}(${T} x) { return ${X}; }\n'
    elif lang == "c++":
        pattern = '${M}${T} ${F}(${T} x) { return ${X}; }\n'
        fields['M'] = 'template<typename T=int> ' if templated else ''
    elif lang == "d":
        # See: https://forum.dlang.org/post/sfldpxiieahuiizvgjeb@forum.dlang.org
        pattern = '${T} ${F}${M}(${T} x) @safe pure nothrow @nogc { return ${X}; }\n'
        fields['M'] = '(T=void)' if templated else ''
    elif lang == "rust":
        if templated:
            pattern = 'fn ${F}<${T} : ${R}>(x: ${T}) -> ${T} { ${X} }\n'
            fields['T'] = 'T'
            fields['R'] = 'Copy + Add<Output = T>'
        else:
            pattern = 'fn ${F}(x: ${T}) -> ${T} { ${X} }\n'
    elif lang == "zig":
        if templated:
            pattern = 'fn ${F}(comptime T: type, x: T) T { return ${X}; }\n'
        else:
            pattern = 'fn ${F}(x: ${T}) ${T} { return ${X}; }\n'
    elif lang == "go":
        pattern = 'func ${F}(x ${T}) ${T} { return ${X} }\n'
    elif lang == "ocaml":
        pattern = 'let ${F} x = ${X}\n'
    elif lang == "v":
        if templated:
            pattern = 'fn ${F}<${T}>(x ${T}) ${T} { return ${X} }\n'
            fields['T'] = 'T'
        else:
            pattern = 'fn ${F}(x ${T}) ${T} { return ${X} }\n'
    elif lang == "julia":
        pattern = 'function ${F}(x${QT})${QT}\n    return ${X}\nend;\n'
        fields['QT'] = ('::' + typ) if templated else ''
    else:
        return                  # no pattern for this language: emit nothing

    f.write(Tm(pattern).substitute(fields))


def generate_test_main_header(lang, types, f, templated):
    """Write the opening line of the generated program's entry point to ``f``.

    :param lang: lower-case language identifier.
    :param types: sequence of type names; ``types[0]`` is read for rust, zig,
        go and v, and for julia when ``templated`` is truthy.
    :param f: writable text stream.
    :param templated: when truthy, the Julia ``main`` gets a ``::<type>``
        return annotation.
    :raises AssertionError: for a language without a known header.
    """
    c_style_main = 'int main(__attribute__((unused)) int argc, __attribute__((unused)) char* argv[]) {\n'

    # Headers that never look at ``types``.
    untyped_headers = {
        "c": c_style_main,
        "c++": c_style_main,
        "java": '    public static void main(String args[]) {\n',
        "d": 'int main(string[] args) {\n',
        "ocaml": 'let () = \n',
    }
    # Headers rendered with the first type bound to ``T``.
    typed_headers = {
        "rust": 'fn main() {\n',
        "zig": 'pub fn main() void {\n',
        "go": 'func main() ${T} {\n',
        "v": 'fn main() {\n',
    }

    if lang in untyped_headers:
        f.write(untyped_headers[lang])
    elif lang in typed_headers:
        f.write(Tm(typed_headers[lang]).substitute(T=types[0]))
    elif lang == "julia":
        return_annotation = ('::' + types[0]) if templated else ''
        f.write(Tm('function main()${QT}\n').substitute(QT=return_annotation))
    else:
        assert False


def generate_linear_test_function_variable(lang, typ, f, templated):
    """Write the declaration of the accumulator ``<typ>_sum`` to ``f``.

    :param lang: lower-case language identifier.
    :param typ: type name; also used as the prefix of the variable name.
    :param f: writable text stream.
    :param templated: when truthy, the Julia declaration carries a
        ``::<typ>`` annotation; ignored for the other languages.
    :raises AssertionError: for a language without a known declaration.
    """
    c_like_declaration = '    ${T} ${T}_sum = 0;\n'
    declaration_by_lang = {
        "c": c_like_declaration,
        "c++": c_like_declaration,
        "d": c_like_declaration,
        "java": '        ${T} ${T}_sum = 0;\n',
        "rust": '    let mut ${T}_sum : ${T} = 0;\n',
        "zig": '    var ${T}_sum: ${T} = 0;\n',
        "go": '    var ${T}_sum ${T} = 0;\n',
        "v": '    mut ${T}_sum := ${T}(0)\n',
        "julia": '    ${T}_sum${QT} = 0;\n',
        "ocaml": '    let ${T}_sum = 0.0 in\n',
    }

    pattern = declaration_by_lang.get(lang)
    if pattern is not None:
        fields = {'T': typ}
        if lang == "julia":
            fields['QT'] = ('::' + typ) if templated else ''
        f.write(Tm(pattern).substitute(fields))
    else:
        assert False


def generate_test_program_2(function_count, lang, root_path, templated):
    """Generate the flat "sample2" test program for ``lang`` and return its path.

    The file ``<root_path>/<lang>/sample2.<ext>`` is (re)written with
    ``function_count`` one-line ``add_<T>_n<N>`` functions for every type
    reported by ``long_types_of_lang`` and a ``main`` that sums one call to
    each of them.  The time spent generating the file is printed.

    :param function_count: number of ``add`` functions emitted per type.
    :param lang: language name; lower-cased before use.
    :param root_path: directory under which the per-language directory is
        created if missing.
    :param templated: only affects Julia output, where a truthy value adds
        ``::<type>`` annotations.
    :returns: path of the generated source file.
    :raises AssertionError: for languages without a main-header or
        sum-variable branch below (for example java and ocaml); the file has
        been partially written by then.
    """
    program_name = "sample2"

    lang = lang.lower()
    types = long_types_of_lang(lang)
    ext = language_file_extension(lang)
    dir_path = os.path.join(root_path, lang)
    os.makedirs(dir_path, exist_ok=True)
    path = os.path.join(dir_path, program_name + "." + ext)

    start = timer()
    with open(path, 'w') as f:

        # package definition
        if lang == "go":
            f.write('''package ''' + program_name + ''';

''')

        # standard io module
        if lang in ["c"]:
            f.write('''#include <stdio.h>
''')
        if lang in ["c++"]:
            f.write('''#include <iostream>
''')
        if lang == "d":
            f.write('''import std.stdio;
''')
        if lang == "rust":
            f.write('''use std::io;
''')
        if lang == "go":
            f.write('''import "fmt";

''')
        if lang == "v":
            f.write('''import os
''')

        # special modules
        if lang == "rust":
            f.write('''use std::process::exit;
''')

        # FUNCTION DEFINITIONS
        for typ in types:
            for findex in range(0, function_count):
                if lang in ["c", "c++"]:
                    f.write(Tm('''${T} add_${T}_n${N}(${T} x) { return x + ${N}; }
''').substitute(T=typ, N=str(findex)))
                elif lang in ["d"]:
                    f.write(Tm('''${T} add_${T}_n${N}(${T} x) pure { return x + ${N}; }
''').substitute(T=typ, N=str(findex)))
                elif lang == "rust":
                    f.write(Tm('''fn add_${T}_n${N}(x: ${T}) -> ${T} { x + ${N} }
''').substitute(T=typ, N=str(findex)))
                elif lang == "zig":
                    f.write(Tm('''fn add_${T}_n${N}(x: ${T}) ${T} { return x + ${N}; }
''').substitute(T=typ, N=str(findex)))
                elif lang == "go":
                    f.write(Tm('''func add_${T}_n${N}(x ${T}) ${T} { return x + ${N} }
''').substitute(T=typ, N=str(findex)))
                elif lang == "ocaml":
                    f.write(Tm('''let add_${T}_n${N} x = x +. ${N}.0
''').substitute(T=typ, N=str(findex)))
                elif lang == "v":
                    f.write(Tm('''fn add_${T}_n${N}(x ${T}) ${T} { return x + ${N} }
''').substitute(T=typ, N=str(findex)))
                elif lang == "julia":
                    # T must be supplied here: Template.substitute raises
                    # KeyError for any placeholder that has no value.
                    f.write(Tm('''function add_${T}_n${N}(x${QT})${QT}
    return x + ${N}
end;
''').substitute(T=typ, QT=(('::' + typ) if templated else ''), N=str(findex)))
                    f.write('\n')

        # MAIN HEADER
        if lang in ["c", "c++"]:
            f.write('''int main(__attribute__((unused)) int argc, __attribute__((unused)) char* argv[])
{
''')
        elif lang == "d":
            f.write('''int main(string[] args)
{
''')
        elif lang == "rust":
            f.write(Tm('''fn main() {
''').substitute(T=types[0]))
        elif lang == "zig":
            f.write(Tm('''pub fn main() void {
''').substitute(T=types[0]))
        elif lang == "go":
            f.write(Tm('''func main() ${T} {
''').substitute(T=types[0]))
        elif lang == "v":
            f.write(Tm('''fn main() ${T} {
''').substitute(T=types[0]))
        elif lang == "julia":
            f.write(Tm('''function main()::${T}
''').substitute(T=types[0]))
        else:
            assert False

        # CALCULATE
        for typ in types:
            if lang in ["c", "c++", "d"]:
                f.write(Tm('''    ${T} ${T}_sum = 0;
''').substitute(T=typ))
            elif lang == "rust":
                f.write(Tm('''    let mut ${T}_sum : ${T} = 0;
''').substitute(T=typ))
            elif lang == "zig":
                f.write(Tm('''    var ${T}_sum: ${T} = 0;
''').substitute(T=typ))
            elif lang == "go":
                f.write(Tm('''    var ${T}_sum ${T} = 0;
''').substitute(T=typ))
            elif lang == "v":
                # NOTE(review): generate_linear_test_function_variable emits
                # `mut <T>_sum := <T>(0)` for V; this `var` form may be stale
                # -- confirm against the V compiler.
                f.write(Tm('''    var ${T}_sum ${T} = 0;
''').substitute(T=typ))
            elif lang == "julia":
                # T must be supplied here as well, otherwise substitute()
                # raises KeyError for the ${T} placeholder.
                f.write(Tm('''    ${T}_sum${QT} = 0;
''').substitute(T=typ, QT=(('::' + typ) if templated else '')))
            else:
                assert False

            for findex in range(0, function_count):
                f.write(Tm('''    ${T}_sum += add_${T}_n${N}(${N});
''').substitute(T=typ, N=str(findex)))

        # MAIN FOOTER
        if lang == "rust":
            f.write(Tm('''    exit(${T}_sum);
}
''').substitute(T=types[0]))
        elif lang == "zig":
            f.write(Tm('''
}
''').substitute(T=types[0]))
        elif lang == 'ocaml':
            f.write(Tm('''    ${T}_sum''').substitute(T=types[0]))
        elif lang == "julia":
            f.write(Tm('''    return ${T}_sum;
end

main()
''').substitute(T=types[0]))
        elif lang == "d":
            f.write(Tm('''    return cast(int)${T}_sum;
}
''').substitute(T=types[0]))
        else:
            f.write(Tm('''    return ${T}_sum;
}
''').substitute(T=types[0]))

    end = timer()
    dur = (end - start)  # time dur
    print("- Generating {} took {:1.3f} seconds ({})".format(path, dur, lang))

    return path


def print_speedup(durs, from_opId, to_opId):
    """Print how many times faster ``from_opId`` was than ``to_opId``.

    ``durs`` maps operation ids to durations.  Nothing is printed unless both
    ids are present; the reported factor is
    ``durs[to_opId] / durs[from_opId]``.
    """
    if from_opId not in durs or to_opId not in durs:
        return

    ratio = durs[to_opId] / durs[from_opId]
    print("- Speedup of {} over {}: {:.2f}".format(from_opId, to_opId, ratio))


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
