#! /usr/bin/env python3
#
# Copyright 2009-2024 The VOTCA Development Team (http://www.votca.org)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import fnmatch
import json
import math
import os
import shutil
import subprocess
import sys
import pathlib
import time

import h5py

VERSION = '@PROJECT_VERSION@ #VOTCA_GIT_ID#'
PROGTITLE = 'THE VOTCA::XTP Benchmark'
PROGDESCR = 'Runs a benchmark on a set of geometries'
VOTCAHEADER = '''\
==================================================
========   VOTCA (http://www.votca.org)   ========
==================================================

{progtitle}

please read and cite: @PROJECT_CITATION@
and submit bugs to @PROJECT_CONTACT@
benchmark, version {version}

'''.format(version=VERSION, progtitle=PROGTITLE)


def xxquit(what=''):
    """Terminate the program with exit status 1.

    If ``what`` is anything other than the empty string it is printed
    first, prefixed with ``ERROR: ``.
    """
    has_message = what != ''
    if has_message:
        print(f"ERROR: {what}")
    sys.exit(1)


def ReadEnergyFromHDF5(filename):
    """Return the ``QMdata/BSE_singlet/eigenvalues`` dataset, flattened.

    The file is opened read-only and closed again before returning; the
    dataset is fully loaded into memory while the file is still open.
    """
    with h5py.File(filename, 'r') as orb:
        singlet_group = orb['QMdata']['BSE_singlet']
        # "[()]" reads the whole dataset so it outlives the file handle.
        eigenvalues = singlet_group['eigenvalues'][()]
    flat = eigenvalues.flatten()
    return flat


def ReadIndividualTimingsFromLog(logfile):
    """Collect the DFT, GW and BSE timings reported in a log file.

    Every line containing "<step> calculation took " contributes one
    entry; the value is the second-to-last whitespace separated token of
    that line, converted to float. A step that is reported more than once
    aborts the program through xxquit.

    Returns a dict mapping the step name to its timing. Steps that never
    show up in the log are simply absent from the dict.
    """
    timings = {}
    with open(logfile, 'r') as handle:
        for entry in handle:
            for label in ("DFT", "GW", "BSE"):
                marker = f"{label} calculation took "
                if marker not in entry:
                    continue
                if label in timings:
                    # A second report for the same step is not expected.
                    xxquit(what=f'Logfile {logfile} is corrupted')
                else:
                    tokens = entry.split()
                    timings[label] = float(tokens[-2])
    return timings


def HasGPUs():
    """Tell whether the help text of ``xtp_tools -h`` mentions ``--gpus``."""
    raw_help = subprocess.check_output(['xtp_tools', '-h'])
    help_text = raw_help.decode('utf8')
    return "--gpus" in help_text


class cd:
    """Temporarily switch the working directory inside a ``with`` block.

    The directory that was current on entry is restored on exit, whether
    or not the body raised.
    """

    def __init__(self, newPath):
        # A leading "~" is expanded right away, not when entering.
        expanded = os.path.expanduser(newPath)
        self.newPath = expanded

    def __enter__(self):
        # Remember where we came from before moving away.
        current = os.getcwd()
        self.savedPath = current
        os.chdir(self.newPath)

    def __exit__(self, etype, value, traceback):
        # Exceptions are not suppressed (implicit None return).
        previous = self.savedPath
        os.chdir(previous)

# =============================================================================
# PROGRAM OPTIONS
# =============================================================================


class XtpHelpFormatter(argparse.HelpFormatter):
    """Help formatter that emits the VOTCA banner in place of the usage."""

    def _format_usage(self, usage, action, group, prefix):
        # The arguments are ignored on purpose: the banner is static.
        banner = VOTCAHEADER
        return banner


# Command line interface of the benchmark driver. The three stages
# (--setup, --run, --analyze) are independent flags and may be combined.
progargs = argparse.ArgumentParser(prog='xtp_benchmark',
                                   formatter_class=lambda prog: XtpHelpFormatter(
                                       prog, max_help_position=70),
                                   description=PROGDESCR)

progargs.add_argument('-t', '--threads',
                      type=int,
                      default=1,
                      help='Number of OPENMP threads')

# nargs="+" yields a list when the option is given, so the default is a
# list as well; OPTIONS.cmdoptions then always has the same type.
progargs.add_argument('-c', '--cmdoptions', nargs="+",
                      default=["gwbse_engine.gwbse.gw.mode=G0W0"],
                      help="Modify options for dftgwbse via command line by e.g. '-c xmltag.subtag=value'. Use whitespace to separate multiple options")

progargs.add_argument('-r', '--run',
                      action='store_true',
                      help='Run benchmark')

progargs.add_argument('-s', '--setup',
                      action='store_true',
                      help='Setup Benchmark')

progargs.add_argument('-i', '--input_folder',
                      default="Geometries/Geometries_ThielSet",
                      help='Folder to take geometries from')

progargs.add_argument('-o', '--output_folder',
                      default="RUN",
                      help='Folder to write benchmark to')

progargs.add_argument('-a', '--analyze',
                      action='store_true',
                      help='Graphically analyse results from outputfolder')

progargs.add_argument('--reference',
                      help='In combination with analyze options compare to a reference')

# NOTE: this rebinds the name HasGPUs from the probe function to its
# boolean result. The rest of the script tests HasGPUs as a boolean, so
# the name is kept as is.
HasGPUs = HasGPUs()
if HasGPUs:
    # Only offer -g when the installed xtp_tools advertises --gpus.
    progargs.add_argument('-g', '--gpus',
                          required=False,
                          type=int,
                          default=0,
                          help='Number of GPUs to use')

OPTIONS = progargs.parse_args()


if OPTIONS.setup:
    # Stage 1: create the benchmark folder with one sub-folder per geometry.
    print(f"Creating benchmark directory: {OPTIONS.output_folder}")
    if os.path.isdir(OPTIONS.output_folder):
        xxquit(f"Folder '{OPTIONS.output_folder}' already exists")
    os.mkdir(OPTIONS.output_folder)

    # Gather [directory, filename] pairs for every xyz file found anywhere
    # below the input folder.
    pattern = "*.xyz"
    input_molecules = []
    for path, subdirs, files in os.walk(OPTIONS.input_folder):
        input_molecules.extend(
            [path, name] for name in files if fnmatch.fnmatch(name, pattern))

    # Each molecule gets a folder named after the xyz file (without its
    # suffix) holding a copy of the geometry called system.xyz.
    benchmark_root = pathlib.Path(OPTIONS.output_folder)
    for source_dir, xyz_name in input_molecules:
        newpath = benchmark_root / pathlib.Path(xyz_name).stem
        os.mkdir(newpath)
        shutil.copyfile(pathlib.Path(source_dir) / xyz_name,
                        newpath / "system.xyz")

    created = "\n".join(entry[1] for entry in input_molecules)
    print(f"Created folders for:\n{created}")

if OPTIONS.run:
    # Stage 2: run dftgwbse in every molecule folder and collect timings
    # and singlet energies into <output_folder>/result.json.

    print("\nStarting benchmark with {} threads".format(OPTIONS.threads))
    if HasGPUs:
        print(" and {} GPUs".format(OPTIONS.gpus))

    # Options handed to "xtp_tools -c"; job_name always comes first.
    cmd_options = "job_name=system "
    if OPTIONS.cmdoptions:
        # Accept both a single string and a list of option strings.
        if isinstance(OPTIONS.cmdoptions, str):
            cmd_options += OPTIONS.cmdoptions
        else:
            cmd_options += " ".join(OPTIONS.cmdoptions)
    print(f"With options: '{cmd_options}'")
    print(30*"-")

    moldata = {}
    totaltime = time.time()
    # NOTE(review): the reference data is loaded here but not used by this
    # stage; kept so the message and the file check stay as before.
    referencedata = {}
    if OPTIONS.reference:
        print("Loading reference data from '{}'".format(OPTIONS.reference))
        with open(OPTIONS.reference, 'r') as fp:
            referencedata = json.load(fp)

    output_root = pathlib.Path(OPTIONS.output_folder)
    if not output_root.is_dir():
        xxquit("Folder '{}' does not exist".format(OPTIONS.output_folder))

    # Only sub-directories are molecules. The folder also holds result.json
    # once a run has finished, and changing into that file would fail on
    # every re-run. Sorting makes the processing order reproducible.
    molecules = sorted(entry for entry in os.listdir(output_root)
                       if (output_root / entry).is_dir())

    for mol in molecules:
        print("Running {}".format(mol).ljust(30), end="\r")
        with cd(output_root / mol):

            start = time.time()
            if HasGPUs:
                cmd = 'xtp_tools -e dftgwbse -t {} -g {} -c {} > dftgwbse.log'.format(
                    OPTIONS.threads, OPTIONS.gpus, cmd_options)
            else:
                cmd = 'xtp_tools -e dftgwbse -t {} -c {} > dftgwbse.log'.format(
                    OPTIONS.threads, cmd_options)
            # The exit status is not inspected; success is judged by the
            # presence of system.orb below.
            os.system(cmd)
            dt = time.time() - start
            molinfo = {}
            molinfo["duration[s]"] = dt
            print("{}".format(mol).ljust(30) +
                  " Duration: \t{:8.2f} seconds.".format(dt), end='')
            if not os.path.isfile("system.orb"):
                print(" Computation Failed")
                molinfo["Status"] = "Failed"
            else:
                S = ReadEnergyFromHDF5("system.orb")
                print(" S1 = {:3.6f}[Hrt]".format(S[0]))
                molinfo["Status"] = "Success"
                for level, energy in enumerate(S, start=1):
                    # Plain floats keep the record JSON serialisable
                    # whatever numeric type the array holds.
                    molinfo[f"S{level}[Hrt]"] = float(energy)

                molinfo.update(ReadIndividualTimingsFromLog("dftgwbse.log"))

        moldata[mol] = molinfo

    print(30*"-")
    dttotal = time.time() - totaltime
    print("Total time: {:8.2f}".format(dttotal))
    goodruns = sum(1 for info in moldata.values()
                   if info["Status"] == "Success")
    if moldata:
        print("{:1.2f}% of runs completed successfully".format(
            float(goodruns)/float(len(moldata))*100.0))
    else:
        # Avoid a division by zero when there was nothing to run.
        print("No molecule folders found in '{}'".format(
            OPTIONS.output_folder))

    # Merge with results of earlier runs; entries of this run win.
    resultsfile = os.path.join(OPTIONS.output_folder, 'result.json')
    if os.path.isfile(resultsfile):
        with open(resultsfile, "r") as f:
            oldresults = json.load(f)
        oldresults.update(moldata)
        moldata = oldresults
        print(f"Updating benchmark data in '{resultsfile}'")
    else:
        print(f"Writing benchmark data to '{resultsfile}'")
    with open(resultsfile, 'w') as fp:
        json.dump(moldata, fp, sort_keys=True, indent=4)


if OPTIONS.analyze:
    # Stage 3: print energies and timings from result.json, optionally
    # side by side with a reference result file.
    resultsfile = OPTIONS.output_folder+'/result.json'
    if not os.path.isfile(resultsfile):
        xxquit(f"No benchmark results found at '{resultsfile}'")
    reference = None
    with open(resultsfile, "r") as f:
        result = json.load(f)

    if OPTIONS.reference:
        with open(OPTIONS.reference, "r") as f:
            reference = json.load(f)

    def _reference_entry(molname):
        """Return the reference record of molname, or {} if there is none."""
        if reference is None:
            return {}
        return reference.get(molname, {})

    def print_timing(molname, keyword):
        """Print one timing of molname, compared to the reference if known.

        A timing missing from the result is reported as not available. If
        the reference has no value for it, only the new value is printed.
        """
        if keyword not in result[molname]:
            # E.g. the log file did not report this step.
            print(f"{keyword}: not available")
            return
        new_value = float(result[molname][keyword])
        ref_entry = _reference_entry(molname)
        if keyword not in ref_entry:
            print(f"{keyword}:{new_value:.2f}")
            return
        ref_value = float(ref_entry[keyword])
        if ref_value != 0.0:
            change = f"{(new_value-ref_value)/ref_value*100:.2f}%"
        else:
            # A relative change against zero is undefined.
            change = "n/a"
        print(
            f"{keyword}: {new_value:.2f}\tref: {ref_value:.2f}\t change:{change}")

    def print_energy(molname, max_levels=5):
        """Print up to max_levels singlet energies of molname.

        Stops at the first level that is not stored in the result. Levels
        that have a reference value are checked against it with
        math.isclose (rel_tol and abs_tol of 1e-4).
        """
        data = result[molname]
        ref_entry = _reference_entry(molname)
        for level in range(1, max_levels + 1):
            key = f"S{level}[Hrt]"
            if key not in data:
                # Fewer states were computed than requested here.
                break
            new_value = float(data[key])
            if key not in ref_entry:
                print(f"S{level}[Hrt]: {new_value:.2f}")
                continue
            ref_value = float(ref_entry[key])
            check = "yes"
            if not math.isclose(new_value, ref_value, rel_tol=1e-4, abs_tol=1e-4):
                check = "WARNING result not close"
            print(
                f"S{level}[Hrt]: {new_value:.5f}\tref: {ref_value:.5f}\tclose: {check}")

    for molname, data in result.items():
        print(molname, " Timings in [s]")
        if data["Status"] == "Success":
            print_energy(molname)
            print_timing(molname, "DFT")
            print_timing(molname, "GW")
            print_timing(molname, "BSE")
            print_timing(molname, "duration[s]")
        else:
            print("not successful")
        print(30*"-")
