##----------------------------------------------------------------------------
## Name:         is_surf_ann.py -
## Purpose:      Train and run a feed forward ANN as an inverse model
##               for the surf_is.py infrasound surf generation model
##               
## Author:       J Park
##
## Created:      1/17/07
##
## Modified:     
##----------------------------------------------------------------------------

import getopt
import sys
from time import clock
# ANN is based on ffnet-0.5 (http://ffnet.sourceforge.net)
# ffnet depends on:
# 1. numpy/scipy (http://scipy.org/)
# 2. NetworkX (https://networkx.lanl.gov/wiki)
# 3. matplotlib (http://matplotlib.sourceforge.net/)
try:
    from ffnet import ffnet, mlgraph, savenet, loadnet
except ImportError: 
    print "Failed to import ffnet."
try:
    from pylab import *
except ImportError: 
    print "Failed to import pylab for matplotlib."

from scipy.interpolate.fitpack2 import UnivariateSpline

#--------------------------------------------------
# Global Parameters
# Module-level switches and shared result buffers. The switches marked with
# an option letter are set to True by ParseCmdLine() when that option is
# given; nothing in this file ever sets one back to False.
DEBUG             = False  # -d

# Processing stages executed by main(), in this order: train, run, plot
TrainNetwork      = False  # -t
RunNetwork        = False  # -r
TestNetwork       = False  # -x  (call ann.test() after each training pass)
testPrintLevel    = 1      # passed to ann.test() as iprint
Plot              = False  # -p
PlotData          = False  # no command line option; overlay raw ANN output on the spline
PlotNetwork       = False  # no command line option; draw the network graph

# Training algorithm selection. BackPropagation defaults to True, so -b has
# no effect and backpropagation cannot be disabled from the command line.
BackPropagation   = True   # -b
GeneticAlgorithm  = False  # -g
ConjugateGradient = False  # -c
BFGS              = False  # -f
TNC               = False  # -n


# Shared between TrainNetworks()/RunNetworks() (writers) and PlotResults()
# (reader).
inputVectors     = []   # ANN input vectors read from the data file
plotLegend       = []   # one legend label per run input, filled by RunNetworks()
ann_out_momentum = []   # per-input ANN outputs, one list per training algorithm
ann_out_genetic  = []
ann_out_cg       = []
ann_out_bfgs     = []
ann_out_tnc      = []

##----------------------------------------------------------------------------
## Main module
def main():
    """Parse the command line, then execute every stage that was requested.

    Stages always run in the fixed order train -> run -> plot; a stage is
    skipped when its global switch is still False after ParseCmdLine().
    """
    ParseCmdLine()

    stages = ( (TrainNetwork, TrainNetworks),
               (RunNetwork,   RunNetworks),
               (Plot,         PlotResults) )

    for requested, stage in stages:
        if requested:
            stage()

    print("Normal Exit")

##----------------------------------------------------------------------------
## 
def _ParseVectorLine(bufferLine):
    """Return the numeric payload of one data-file line as a list of floats.

    A ':' delimits the line header from the data, so the payload is
    everything after the last ':' split on whitespace.  float() raises
    ValueError when a word is not numeric.
    """
    dataBuffer = bufferLine[ bufferLine.rfind(':') + 1 : ]
    return [ float(word) for word in dataBuffer.split() ]


def _TrainAndSave(ann, banner, trainMethod, trainArgs, netFile, testLabel,
                  inputs, targets, randomize):
    """Run one training algorithm on ann, time it, save it and evaluate it.

    ann         -- the network object to train (modified in place)
    banner      -- progress text written before training starts
    trainMethod -- name of the ann training method to call
    trainArgs   -- keyword arguments forwarded to that method
    netFile     -- file name handed to savenet() after training
    testLabel   -- algorithm name used in the "Testing ..." message
    inputs      -- list of input vectors
    targets     -- list of target vectors, paired with inputs by position
    randomize   -- when true, call ann.randomweights() before training

    Returns the list of network outputs, one entry per input vector.
    """
    if randomize:
        ann.randomweights()

    # Write the banner without a newline so the elapsed time lands on the
    # same line; flush so it is visible while the (long) training runs.
    sys.stdout.write(banner + " ")
    sys.stdout.flush()

    startTime = clock()
    getattr(ann, trainMethod)(inputs, targets, **trainArgs)
    elapsed = clock() - startTime
    print("<- time  %s" % (elapsed,))

    savenet(ann, netFile)
    outputs = [ ann([x]) for x in inputs ]

    if TestNetwork:
        print("")
        print("Testing %s..." % testLabel)
        ann.test(inputs, targets, iprint = testPrintLevel)

    return outputs


def TrainNetworks():
    """Train every enabled ANN variant on the surf_is.py training set.

    Reads "surf_is_TrainANN.dat", where each line is either an input vector
    (line starts with 'I') or a target vector (line starts with 'O'), with
    the numeric data following the last ':' on the line.  Inputs and targets
    are paired by position.

    For each enabled algorithm switch the same network object is trained,
    saved to its own "ann_*.net" file and evaluated on the training inputs;
    the outputs are stored in the matching ann_out_* global.  The global
    inputVectors is rebound to the inputs read here.

    Exits through sys.exit(1) when the file contains a line that is neither
    an input nor a target, when vector lengths are inconsistent, or when the
    inputs and targets cannot be paired.
    """
    global ann_out_momentum, ann_out_genetic, ann_out_cg
    global ann_out_bfgs, ann_out_tnc
    global inputVectors

    #--------------------------------------------------------------------
    # Read in the training data from the surf_is.py TrainANN output file
    TrainANNFile = "surf_is_TrainANN.dat"
    fdTrainANN   = OpenFile(TrainANNFile, 'r')
    trainBuffer  = fdTrainANN.readlines()
    fdTrainANN.close()
    print("TrainNetworks() Read %d lines from %s" % ( len(trainBuffer), TrainANNFile ))

    #--------------------------------------------------------------------
    # Extract the training data into vectors for the ANN I/O
    # Each line in the trainBuffer represents an Input or Output vector
    # The vectors are paired in order, one Input followed by its Output
    # Input is the Infrasound Spectrum
    # Target is the Breaking Wave Spectrum
    #
    # Start from a fresh input list: the targets are local to this call, so
    # inputs left over from an earlier call would be paired with the wrong
    # targets.
    inputVectors  = []
    targetVectors = []
    numInputs  = 0
    numOutputs = 0

    for i, bufferLine in enumerate(trainBuffer):
        if bufferLine.startswith('I'):
            data = _ParseVectorLine(bufferLine)
            inputVectors.append( data )
            if not numInputs:
                numInputs = len( data )
            elif numInputs != len( data ):
                print("Error: TrainNetworks() Number of inputs doesn't match in line %d\n" % (i + 1))
                sys.exit(1)
        elif bufferLine.startswith('O'):
            data = _ParseVectorLine(bufferLine)
            targetVectors.append( data )
            if not numOutputs:
                numOutputs = len( data )
            elif numOutputs != len( data ):
                print("Error: TrainNetworks() Number of outputs doesn't match in line %d\n" % (i + 1))
                sys.exit(1)
        else:
            print("Error: TrainNetworks() Failed to find valid line in %s\n" % TrainANNFile)
            sys.exit(1)

    print("TrainNetworks() numInputs= %d" % numInputs)
    print("TrainNetworks() numOutputs= %d" % numOutputs)

    # Training needs at least one pair and exactly one target per input;
    # report that here instead of failing inside the network code.
    if not inputVectors or len(inputVectors) != len(targetVectors):
        print("Error: TrainNetworks() Found %d input and %d output vectors in %s\n"
              % ( len(inputVectors), len(targetVectors), TrainANNFile ))
        sys.exit(1)

    #--------------------------------------------------------------------
    # Create multilayer network with full connectivity and bias nodes
    conec = mlgraph( (numInputs, numInputs + 10 , numOutputs + 5, numOutputs), biases = True )

    # Create ANN
    ann = ffnet(conec)

    #--------------------------------------------------------------------
    # Train networks
    # The same network object is reused by every algorithm.
    # NOTE(review): BackPropagation and TNC do not re-randomize the weights,
    # so TNC continues from whatever the previously enabled algorithm left
    # behind -- confirm this is intended.
    if BackPropagation:
        ## Simple backpropagation training with momentum.
        ann_out_momentum = _TrainAndSave(
            ann, "-> train_momentum", "train_momentum",
            { "maxiter":20000 , "disp":0 },
            "ann_momentum.net", "BackPropagation",
            inputVectors, targetVectors, False )

    if GeneticAlgorithm:
        ## Global weights optimization with genetic algorithm.
        ## For more info see pikaia homepage and documentation:
        ## http://www.hao.ucar.edu/Public/models/pikaia/pikaia.html
        ann_out_genetic = _TrainAndSave(
            ann, "-> train_genetic ", "train_genetic",
            { "individuals":20 , "generations":500 },
            "ann_genetic.net", "GeneticAlgorithm",
            inputVectors, targetVectors, True )

    if ConjugateGradient:
        ## Train network with conjugate gradient algorithm using the
        ## nonlinear conjugate gradient algorithm of Polak and Ribiere.
        ## See Wright, and Nocedal 'Numerical Optimization', 1999, pg. 120-122.
        ann_out_cg = _TrainAndSave(
            ann, "-> train_cg      ", "train_cg",
            { "maxiter":20000 , "disp":0 },
            "ann_cg.net", "ConjugateGradient",
            inputVectors, targetVectors, True )

    if BFGS:
        ## Train network with constrained version of quasi-Newton method
        ## of Broyden, Fletcher, Goldfarb, and Shanno (BFGS).
        ann_out_bfgs = _TrainAndSave(
            ann, "-> train_bfgs    ", "train_bfgs",
            { "factr":1e7 , "pgtol":1e-5 , "maxfun":15000, "iprint":-1 },
            "ann_bfgs.net", "BFGS",
            inputVectors, targetVectors, True )

    if TNC:
        ## Train network with a TNC algorithm.
        ## TNC is a C implementation of TNBC, a truncated newton
        ## optimization package originally developed by Stephen G. Nash in Fortran.
        ann_out_tnc = _TrainAndSave(
            ann, "-> train_tnc     ", "train_tnc",
            { "maxfun":5000 , "messages":0 },
            "ann_tnc.net", "TNC",
            inputVectors, targetVectors, False )


##----------------------------------------------------------------------------
## 
def _ApplySavedNetwork(netFile, inputs):
    """Load the network stored in netFile and return its output for each input."""
    ann = loadnet( netFile )
    return [ ann([x]) for x in inputs ]


def RunNetworks():
    """Apply the previously saved networks to the surf_is.py run data set.

    Reads "surf_is_RunANN.dat"; every line must start with 'I' and carry its
    numeric data after the last ':'.  For each line a legend label is built
    from two header fields: the text from the first 'w' up to the following
    ':' and the text from 'SWH' up to the following ':'.

    The globals inputVectors and plotLegend are rebound to the data read
    here, and for every enabled algorithm switch the matching "ann_*.net"
    file is loaded and its outputs stored in the ann_out_* global.

    Exits through sys.exit(1) on a line that does not start with 'I' or on
    inconsistent input vector lengths.
    """
    global ann_out_momentum, ann_out_genetic, ann_out_cg
    global ann_out_bfgs, ann_out_tnc
    global inputVectors, plotLegend

    #--------------------------------------------------------------------
    # Read in the input data from the surf_is.py RunANN output file
    RunANNFile = "surf_is_RunANN.dat"
    fdRunANN   = OpenFile(RunANNFile, 'r')
    runBuffer  = fdRunANN.readlines()
    fdRunANN.close()
    print("RunNetworks() Read %d lines from %s" % ( len(runBuffer), RunANNFile ))

    #--------------------------------------------------------------------
    # Extract the data into vectors for ANN Input
    #
    # Start from empty lists: when training ran earlier in the same process
    # inputVectors still holds the training inputs, and appending to them
    # would shift every legend label onto the wrong curve.
    inputVectors = []
    plotLegend   = []
    numInputs    = 0

    for i, bufferLine in enumerate(runBuffer):
        if not bufferLine.startswith('I'):
            print("Error: RunNetworks() Failed to find valid line in %s\n" % RunANNFile)
            sys.exit(1)

        # get the wind and SWH
        windStartIndex = bufferLine.find('w')
        windEndIndex   = bufferLine.find(':', windStartIndex)
        windSpeed      = bufferLine[ windStartIndex : windEndIndex ]
        SWHStartIndex  = bufferLine.find('SWH')
        SWHEndIndex    = bufferLine.find(':', SWHStartIndex)
        SWH            = bufferLine[ SWHStartIndex : SWHEndIndex ]
        plotLegend.append( windSpeed + " " + SWH )

        # ':' delimits the header from the data: take everything after the
        # last ':' and convert the whitespace separated words to floats
        dataBuffer = bufferLine[ bufferLine.rfind(':') + 1 : ]
        data = [ float(word) for word in dataBuffer.split() ]

        inputVectors.append( data )
        if not numInputs:
            numInputs = len( data )
        elif numInputs != len( data ):
            print("Error: RunNetworks() Number of inputs doesn't match in line %d\n" % (i + 1))
            sys.exit(1)

    print("RunNetworks() numInputs= %d" % numInputs)

    #--------------------------------------------------------------------
    # Apply the input data to the ANN's
    if BackPropagation:
        ann_out_momentum = _ApplySavedNetwork( "ann_momentum.net", inputVectors )

    if GeneticAlgorithm:
        ann_out_genetic = _ApplySavedNetwork( "ann_genetic.net", inputVectors )

    if ConjugateGradient:
        ann_out_cg = _ApplySavedNetwork( "ann_cg.net", inputVectors )

    if BFGS:
        ann_out_bfgs = _ApplySavedNetwork( "ann_bfgs.net", inputVectors )

    if TNC:
        ann_out_tnc = _ApplySavedNetwork( "ann_tnc.net", inputVectors )


##----------------------------------------------------------------------------
## 
def _PlotOutputPanel(breakHeightX, xs, annOutputs, panelTitle, showXLabel):
    """Draw one ANN output panel into the current subplot.

    breakHeightX -- abscissa of the raw ANN output values
    xs           -- finer abscissa on which the smoothing spline is evaluated
    annOutputs   -- sequence of ANN output vectors, one curve each
    panelTitle   -- title text for the subplot
    showXLabel   -- when true, label the x axis with "freq (Hz)"
    """
    for spectrum in annOutputs:
        # cubic spline of output spectrum, clipped so it never goes negative
        spline = UnivariateSpline( breakHeightX, spectrum, s=0.01 )
        ys     = spline( xs )
        plot( xs, maximum( 0., ys ), linewidth=2)
        # raw output spectrum
        if PlotData:
            plot(breakHeightX, spectrum, linewidth=2)
    title(panelTitle)
    if showXLabel:
        xlabel( "freq (Hz)" )
    ylabel( "ft" )
    legend( plotLegend )
    xticks( [0.02, 0.08, 0.14, 0.2] )


def PlotResults():
    """Plot the ANN inputs and the outputs of every plotted algorithm.

    The first subplot shows all vectors in the global inputVectors; one
    further subplot is added for each enabled switch among BackPropagation,
    ConjugateGradient, BFGS and TNC, drawn from the matching ann_out_*
    global.  GeneticAlgorithm output is not plotted.

    The figure is saved as "is_surf_Train.ps" or "is_surf_Run.ps" (Run wins
    when both stages ran) and as "is_surf.ps" when neither stage ran, then
    shown on screen.
    """
    # the input (Infrasound Spectrum) abscissa has values from linspace(1., 20., 20)
    infrasoundX = linspace(1., 20., 20)
    # the output (Break Wave Spectrum) abscissa has values from linspace(0.035, 0.2, 20)
    breakHeightX = linspace(0.035, 0.2, 20)
    xs           = linspace(0.035, 0.2, 150)

    # Output panels in drawing order: (ANN outputs, title, always label x).
    # The BackPropagation panel only gets an x label in the one-column layout.
    panels = []
    if BackPropagation:
        panels.append( (ann_out_momentum,
                        'Breaking Wave Spectrum BackPropagation ANN', False) )
    if ConjugateGradient:
        panels.append( (ann_out_cg,
                        'Breaking Wave Spectrum ConjugateGradient ANN', True) )
    if BFGS:
        panels.append( (ann_out_bfgs, 'Hb  BFGS ANN', True) )
    if TNC:
        panels.append( (ann_out_tnc, 'Hb TNC ANN', True) )

    # subplot( nRows, nColumns, nFigure )
    # One column while there are at most three subplots, two columns above.
    nRows   = 1 + len( panels )
    nCols   = 1
    nFigure = 1

    if nRows == 4:
        nCols = 2
        nRows = 2
    elif nRows == 5:
        nCols = 2
        nRows = 3

    subplot( nRows, nCols, nFigure )
    for inputVector in inputVectors:
        plot(infrasoundX, inputVector, linewidth=2)
    title('Acoustic Spectrum ANN Inputs')
    ylabel( "dB" )

    for annOutputs, panelTitle, alwaysXLabel in panels:
        nFigure += 1
        subplot( nRows, nCols, nFigure )
        _PlotOutputPanel( breakHeightX, xs, annOutputs, panelTitle,
                          alwaysXLabel or nCols == 1 )

    # Fall back to a neutral name so savefig() still has a file name when
    # plotting was requested without a train or run stage.
    plotName = "is_surf"
    if TrainNetwork:
        plotName = "is_surf_Train"
    if RunNetwork:
        plotName = "is_surf_Run"

    # savefig will output to these formats:
    # eps, jpeg, pdf, png, ps, svg
    # the format is deduced by the filename extension (.png is default)
    savefig( plotName + '.ps' )

    if PlotNetwork:
        # The network objects are local to TrainNetworks()/RunNetworks(), so
        # one is only available if something published it as module-level
        # 'ann'; report that instead of raising NameError.
        net = globals().get( "ann" )
        if net is None:
            print("PlotResults() No network available to draw.")
        else:
            try:
                import networkx
                nFigure += 1
                figure( nFigure )
                networkx.draw_circular(net.graph)
            except ImportError:
                print("Failed to import networkx.")

    # can only call show once!
    show()

##----------------------------------------------------------------------------
def usage():
    """Print the program name followed by one line per command line option."""
    optionLines = ( "\t -h (help)",
                    "\t -d (debug)",
                    "\t -t (train network)",
                    "\t -r (run network)",
                    "\t -x (test network)",
                    "\t -p (plot)",
                    "\t -b (BackPropagation default)",
                    "\t -g (GeneticAlgorithm)",
                    "\t -c (ConjugateGradient)",
                    "\t -f (BFGS)",
                    "\t -n (TNC)" )

    print("%s %s" % ( "Usage: ", sys.argv[0] ))
    for optionLine in optionLines:
        print(optionLine)

##----------------------------------------------------------------------------
def ParseCmdLine():
    """Set the module-level switches from the options in sys.argv.

    Every recognised option letter sets its switch to True; options are
    handled in command line order.  With no options at all the usage text
    is printed and the function returns with all switches unchanged.

    Raises SystemExit with code 1 for -h (after printing the usage text)
    and with code 2 for an option getopt does not accept, so that a
    mistyped command line is not reported as a normal run.
    """
    # option letter -> name of the module-level switch it turns on
    switches = { "-d": "DEBUG",
                 "-t": "TrainNetwork",
                 "-r": "RunNetwork",
                 "-x": "TestNetwork",
                 "-p": "Plot",
                 "-b": "BackPropagation",
                 "-g": "GeneticAlgorithm",
                 "-c": "ConjugateGradient",
                 "-f": "BFGS",
                 "-n": "TNC" }

    try:
        opts, args = getopt.getopt(sys.argv[1:], "hdtrxpbgcfn")
    except getopt.GetoptError:
        print("ParseCmdLine(): GetoptError:  %s" % (sys.exc_info()[1],))
        usage()
        sys.exit(2)

    if len(opts) == 0:
        usage()
        return

    # Assigning through globals() replaces one 'global' declaration plus
    # one 'if' per option.
    moduleGlobals = globals()
    for optElement in opts:
        option = optElement[0]
        if option == "-h":
            usage()
            sys.exit(1)
        if option in switches:
            moduleGlobals[ switches[option] ] = True

    return

##----------------------------------------------------------------------------
def OpenFile(file, mode):
    """Open file with the given mode and return the file object.

    file -- name of the file to open
    mode -- mode string handed to open()

    An IOError or OSError from open() is reported on stdout and turned into
    SystemExit("OpenFile()").  Any other exception is reported and then
    re-raised unchanged.
    """
    print("Opening file: %s" % (file,))

    try:
        # open() will raise an IOError exception if it fails
        fd = open(file, mode)

    except IOError:
        print("OpenFile(): I/O error: %s" % (sys.exc_info()[1],))
        sys.exit("OpenFile()")

    except OSError:
        print("OpenFile(): OS error: %s" % (sys.exc_info()[1],))
        sys.exit("OpenFile()")

    except:
        # Not an open() failure we know how to describe: name the exception
        # type and let the caller see the original exception.
        print("OpenFile(): Unexpected error: %s" % (sys.exc_info()[0],))
        raise

    return fd
    
##----------------------------------------------------------------------------
## Provide for cmd line invocation
if __name__ == "__main__":
    # Only run when executed as a script; importing the module has no
    # effect beyond the definitions above.
    main()