#!/usr/bin/env python3
# -*- coding: utf-8 -*-

#***********************************************************************
# This file is part of OpenMolcas.                                     *
#                                                                      *
# OpenMolcas is free software; you can redistribute it and/or modify   *
# it under the terms of the GNU Lesser General Public License, v. 2.1. *
# OpenMolcas is distributed in the hope that it will be useful, but it *
# is provided "as is" and without any express or implied warranties.   *
# For more details see the full text of the license in the file        *
# LICENSE or in <http://www.gnu.org/licenses/>.                        *
#                                                                      *
# Copyright (C) 2017, Ignacio Fdez. Galván                             *
#***********************************************************************

import sys
from glob import glob
from os.path import abspath, join, getmtime
from os import environ
import re
from fnmatch import translate
import json

# colorama is optional: if it cannot be imported, define do-nothing
# replacements so that "init", "Fore" and "Style" can be used unconditionally
# (all the color codes are then empty strings)
try:
  from colorama import init, Fore, Style
except ImportError:
  class Dummy(object):
    pass

  def init(*args, **kwargs):
    pass

  Fore = Dummy()
  Fore.RED = Fore.BLUE = Fore.GREEN = ''
  Style = Dummy()
  Style.RESET_ALL = ''

# Help text, printed when no basis label is given in the command line
# The placeholders are color codes: {0} reset, {1} blue, {2} red, {3} green
helptext = '''
{1}################################################################################
#                  Querying the OpenMolcas basis set library                   #
################################################################################{0}

Use with a command-line argument that is a basis label in Molcas format:

       {2}[element].[name].[author].[primitives].[contraction].[aux. labels]*
          (1)     (2)     (3)        (4)           (5)          (6){0}

trailing dots can be omitted.

Examples:

In the simplest form, just write an element symbol to get a list of basis sets
for that element:
  {3}Fe{0}               lists all basis sets for iron
  
Additional fields can be included to restrict the search:
  {3}Fe..jensen{0}       lists all basis sets for iron, with "jensen" in the author
                   field (case insensitive)
  {3}Fe.....ECP{0}       lists all basis sets for iron that use ECP
                   
Wildcards are supported:
  {3}Fe.*cc-p*v?z*{0}    lists all basis sets for iron belonging to the "cc" families
                   
A single space matches empty fields (needs quotes):
  {3}"Fe..... ."{0}      lists all basis sets for iron with an empty 6th field
                   (probably all-electron)
                   
Minimum number of functions for ANO-type basis sets can be specified:
  {3}Fe....8s7p7d4f{0}   lists all basis sets for iron with a [8s7p7d4f] contraction or
                   larger if the basis set supports using fewer functions

Exact matches are printed in green, "fuzzy" matches in red
'''.format(Style.RESET_ALL, Fore.BLUE, Fore.RED, Fore.GREEN)

# An optional "-bw" flag as first argument selects the "compat" mode.
# The flag is removed from the argument list, so that whatever follows it
# becomes the first argument
flag = sys.argv[1] if (len(sys.argv) > 1) else None
compat = (flag == '-bw')
if (compat):
  del sys.argv[1]

# Use colors in terminal, not if output is redirected to file
# (and never in "compat" mode)
if ((not sys.stdout.isatty()) or compat):
  init(strip=True)
else:
  init()

# Matches a line consisting of "/" followed by a basis set label (captured)
atom_basis = re.compile(r'^/(\S*)\s*$')
# Matches a line with two whitespace-separated words: an alias and the basis
# set it stands for. The alias cannot start with "#" (comment lines) nor with
# whitespace, otherwise blank-looking lines and indented comments would be
# captured as aliases too
alias = re.compile(r'^\s*([^#\s]\S*)\s+(\S*)\s*$')
# Names of the fields in a basis set label, in the order they appear
props = ['atom', 'name', 'auth', 'prim', 'func', 'aux1', 'aux2']
# Valid angular momentum labels for the [contraction] field
angs = ['s', 'p', 'd', 'f', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'q', 'r', 't', 'u']

# Split a basis set label into its fields and return them as a dict:
# - The fields are separated by "." and come in the order given by "props"
# - Trailing dots in the label are ignored
# - Fields missing from the label are set to empty strings
# - Fields beyond the last one in "props" are dropped
def parse_basis(label):
  fields = label.rstrip('.').split('.')
  missing = len(props) - len(fields)
  if (missing > 0):
    fields += missing*['']
  return dict(zip(props, fields))

# Return the number of contracted functions with angular momentum "ang" in a
# [contraction] string (e.g. 7 for "p" in "8s7p7d4f"), or 0 if there are none
def _count_functions(contraction, ang):
  found = re.search(r'\d+(?={0})'.format(ang), contraction)
  return int(found.group()) if found else 0

# Decide if a given field in a basis matches the template
# "field" is one of the names in "props", "basis" and "template" are parsed
# labels (dicts). Returns True or False.
# A field matches if:
# - The corresponding template field is empty
# - The field is equal to the template (case-insensitive, wildcards are parsed,
#   a single space in the template matches empty fields only)
# - For the "func" field, and if the "name" field is right, if the numbers of
#   contracted functions is equal or greater than in the template for all
#   angular momenta (i.e., one can select fewer functions for these types of bases)
#   In this case the template can only contain digits and the labels in "angs"
#   (in any case), otherwise an error is printed and the program exits
def match_field(field, basis, template):
  wanted = template[field]
  if (wanted == ''):
    return True
  if (wanted == ' '):
    return (basis[field] == '')
  if ((field == 'func') and
      (basis['name'].upper().startswith('ANO') or
       (basis['name'].upper() in ['ECP', 'PSD', 'RYDBERG'])
      )
     ):
    # The labels are compared in lowercase, so they must be validated in
    # lowercase too (otherwise e.g. "8S7P" would be rejected)
    wanted = wanted.lower()
    # Catch wrong angular momentum labels (e.g. "j")
    if (any(i not in angs for i in re.sub(r'\d', '', wanted))):
      print('The [contraction] field ({0}) contains an invalid label.'.format(Fore.RED + template[field] + Style.RESET_ALL))
      print('Valid labels are: {0}'.format(Fore.BLUE + ', '.join(angs) + Style.RESET_ALL))
      print()
      sys.exit(1)
    # The basis must have at least as many functions as requested,
    # for every angular momentum
    available = basis[field].lower()
    return all(_count_functions(available, i) >= _count_functions(wanted, i) for i in angs)
  # Compare with wildcard parsing
  return bool(re.match(translate(wanted.upper()), basis[field].upper()))

# Decide if a parsed basis set matches a template
# Every field in "props" has to match; the check stops at the first
# field that does not
def match_basis(basis, template):
  return all(match_field(k, basis, template) for k in props)

# Return the entries of "basis_alias" whose alias (first item of the entry)
# matches the template, in the same order
def filter_basis_alias(template):
  selected = []
  for entry in basis_alias:
    if (match_basis(parse_basis(entry[0]), template)):
      selected.append(entry)
  return selected

# Return the (parsed) basis sets in "basis_sets" that match the template,
# in the same order
def filter_basis_sets(template):
  selected = []
  for parsed in basis_sets:
    if (match_basis(parsed, template)):
      selected.append(parsed)
  return selected

# Add color codes to a string by comparing it with a template string
# - If either the template or the string is empty, return the string as is
# - If both are equal (case insensitive), the string is colored green
# - Otherwise, the string is colored red
def color(string, template):
  if ((template != '') and (string != '')):
    same = (string.upper() == template.upper())
    tint = Fore.GREEN if same else Fore.RED
    return tint + string + Style.RESET_ALL
  return string

# Build a basis label out of a parsed basis set
# - The fields are taken in the order given by "props" and joined with "."
# - If a template is given, each field is first colored by comparing it with
#   the same field in the template
# - Any trailing "." is removed from the result
def name_basis(basis, template=None):
  fields = [basis[k] for k in props]
  if (template is not None):
    fields = [color(f, template[k]) for f, k in zip(fields, props)]
  return '.'.join(fields).rstrip('.')

# First get the template from the command line,
# so we can spare reading the library if there is none:
# in that case just print the help text and quit
if (len(sys.argv) < 2):
  print(helptext)
  sys.exit(0)
template = parse_basis(sys.argv[1])

# Base directory, from the MOLCAS environment variable
# (the current directory if it is not set)
MOLCAS = environ.get('MOLCAS', '.')

# Locate the basis library and find out its latest modification time
# (of the directory itself or of anything directly inside it),
# which decides whether the database file can be used or must be rewritten
basis_lib = abspath(join(MOLCAS, 'basis_library'))
basis_files = glob(join(basis_lib, '*'))
mtime = max([getmtime(basis_lib)] + [getmtime(entry) for entry in basis_files])

# Get a list of all basis sets and all aliases
# "basis_sets": list of parsed basis sets (dicts)
# "basis_alias": list of pairs: (alias, basis set template)
# Both lists are cached in a database file (two lines of JSON), which is
# only used if it is not older than the basis library
db_file = join(MOLCAS, 'data', 'basis.db')
basis_sets = None
basis_alias = None
try:
  if (getmtime(db_file) >= mtime):
    with open(db_file, 'r') as f:
      basis_sets = json.loads(f.readline())
      basis_alias = json.loads(f.readline())
except (IOError, OSError, ValueError):
  # Missing, unreadable or corrupt database (json raises a ValueError
  # subclass on bad data): discard anything that was read
  basis_sets = None
if ((basis_sets is None) or (basis_alias is None)):
  # No usable database, rebuild it from the basis library files
  basis_sets = []
  basis_alias = []
  for basis in sorted(basis_files):
    try:
      with open(basis, 'r') as f:
        for line in f:
          match = atom_basis.match(line)
          if (match):
            basis_sets.append(parse_basis(match.group(1)))
    except IOError:
      # Skip anything that cannot be read as a file
      pass
  with open(join(basis_lib, 'basis.tbl'), 'r') as f:
    for line in f:
      match = alias.match(line)
      if (match):
        basis_alias.append(match.group(1, 2))
  # Attempt to save the database for later use, this is only an optimization
  # so failing to write the file is not an error
  try:
    with open(db_file, 'w') as f:
      json.dump(basis_sets, f)
      f.write('\n')
      json.dump(basis_alias, f)
      f.write('\n')
  except (IOError, OSError):
    pass

# Print the location of the basis library
if (not compat):
  print(Fore.BLUE + 'Basis library:' + Style.RESET_ALL)
  print(basis_lib)

# Print the aliases that match the template given as command-line argument,
# preceded by a header if there is any
fmt = '{0}              {1}' if compat else '{0}   ->   {1}'
found_aliases = filter_basis_alias(template)
if (found_aliases):
  if (compat):
    print('#RECOMMENDED')
  else:
    print('')
    print(Fore.BLUE + 'Basis aliases for ' + Fore.GREEN + name_basis(template) + Style.RESET_ALL + '\n')
for short, target in found_aliases:
  parsed = parse_basis(short)
  # Can't use width in format, because it doesn't work with ANSI codes,
  # so the padding is computed from the uncolored name
  padding = (27 - len(name_basis(parsed)))*' '
  print(fmt.format(name_basis(parsed, template) + padding, name_basis(parse_basis(target), template)))

# Print the basis sets that match the template, preceded by a header,
# or a message if there are none
found_sets = filter_basis_sets(template)
if (found_sets):
  if (compat):
    print('#OTHER')
  else:
    print('')
    print(Fore.BLUE + 'Basis sets for ' + Fore.GREEN + name_basis(template) + Style.RESET_ALL + '\n')
  for parsed in found_sets:
    print(name_basis(parsed, template))
else:
  print('')
  print('No basis sets found matching the query: {0}'.format(Fore.RED + sys.argv[1] + Style.RESET_ALL))
print('')
