##----------------------------------------------------------------------------
## Name:        is_surf_ann.py
## Purpose:     Train and run a feed forward ANN as an inverse model
##              for the surf_is.py infrasound surf generation model
##
## Author:      J Park
##
## Created:     1/17/07
##
## Modified:
##----------------------------------------------------------------------------
"""Train and run feed-forward ANNs as inverse models for surf_is.py.

The networks map an infrasound spectrum (ANN input) to a breaking wave
spectrum (ANN output).  Training pairs are read from surf_is_TrainANN.dat and
run-time inputs from surf_is_RunANN.dat.  Each trained network is saved to an
ann_<algorithm>.net file so that a later -r run can reload it.

Only single-argument print(...) calls are used so the module runs under both
Python 2 and Python 3.
"""
import getopt
import sys

try:
    # time.clock() was removed in Python 3.8; perf_counter() replaces it.
    from time import perf_counter as clock
except ImportError:
    from time import clock

# ANN is based on ffnet-0.5 (http://ffnet.sourceforge.net)
# ffnet depends on:
#   1. numpy/scipy (http://scipy.org/)
#   2. NetworkX (https://networkx.lanl.gov/wiki)
#   3. matplotlib (http://matplotlib.sourceforge.net/)
try:
    from ffnet import ffnet, mlgraph, savenet, loadnet
except ImportError:
    print("Failed to import ffnet.")

try:
    # Supplies linspace, maximum, subplot, plot, title, xlabel, ylabel,
    # legend, xticks, figure, savefig and show used by PlotResults().
    from pylab import *
except ImportError:
    print("Failed to import pylab for matplotlib.")

# Public location; the private scipy.interpolate.fitpack2 module is deprecated.
from scipy.interpolate import UnivariateSpline

#--------------------------------------------------
# Global Parameters (set from the command line by ParseCmdLine)
DEBUG = False               # -d
TrainNetwork = False        # -t
RunNetwork = False          # -r
TestNetwork = False         # -x
testPrintLevel = 1          # iprint level passed to ffnet.test()
Plot = False                # -p
PlotData = False            # also plot the raw (unsplined) ANN outputs
PlotNetwork = False         # also draw the network graph (needs networkx)

BackPropagation = True      # -b
GeneticAlgorithm = False    # -g
ConjugateGradient = False   # -c
BFGS = False                # -f
TNC = False                 # -n

inputVectors = []           # ANN input vectors from the latest Train/Run step
plotLegend = []             # one "w... SWH..." label per RunANN input vector
lastNetwork = None          # last ffnet trained or loaded (for PlotNetwork)

# ANN answers for each entry of inputVectors, one list per training algorithm
ann_out_momentum = []
ann_out_genetic = []
ann_out_cg = []
ann_out_bfgs = []
ann_out_tnc = []


##----------------------------------------------------------------------------
## Main module
def main():
    """Parse the command line, then train, run and/or plot as requested."""
    ParseCmdLine()

    if TrainNetwork:
        TrainNetworks()

    if RunNetwork:
        RunNetworks()

    if Plot:
        PlotResults()

    print("Normal Exit")


##----------------------------------------------------------------------------
## Input helpers
def ExitWithError(message):
    """Print *message* and terminate the script with exit status 1."""
    print(message)
    sys.exit(1)


def ReadDataLines(fileName):
    """Return all lines of *fileName* (exits via OpenFile() if it can't be opened)."""
    fd = OpenFile(fileName, 'r')
    try:
        return fd.readlines()
    finally:
        fd.close()


def ParseDataLine(bufferLine):
    """Return the floats that follow the last ':' of a data line.

    ':' delimits the header from the data, so everything up to and including
    the last ':' is ignored and the remaining whitespace-separated words are
    converted with float().
    """
    dataBuffer = bufferLine[bufferLine.rfind(':') + 1:]
    return [float(word) for word in dataBuffer.split()]


def HeaderField(bufferLine, key):
    """Return the header text from the first *key* up to the next ':'.

    Returns '' when *key* does not occur in the line.  If no ':' follows
    *key*, the field runs to the end of the (right-stripped) line.
    """
    start = bufferLine.find(key)
    if start < 0:
        return ""
    end = bufferLine.find(':', start)
    if end < 0:
        end = len(bufferLine.rstrip())
    return bufferLine[start:end]


def ParseTrainingData(trainBuffer, fileName):
    """Split TrainANN lines into (inputVectors, targetVectors).

    Lines starting with 'I' are ANN input vectors (Infrasound Spectrum) and
    lines starting with 'O' are target vectors (Breaking Wave Spectrum);
    they are paired in file order.  Blank lines are skipped.  Exits the
    script if a line has an unknown tag, a vector's length differs from the
    first vector of its kind, there are no input vectors, or the number of
    input and target vectors differs.
    """
    inputs = []
    targets = []
    for lineNumber, bufferLine in enumerate(trainBuffer, 1):
        if not bufferLine.strip():
            continue
        if bufferLine.startswith('I'):
            vectors, kind = inputs, "inputs"
        elif bufferLine.startswith('O'):
            vectors, kind = targets, "outputs"
        else:
            ExitWithError("Error: TrainNetworks() Failed to find valid line in %s\n"
                          % fileName)
        data = ParseDataLine(bufferLine)
        if vectors and len(data) != len(vectors[0]):
            ExitWithError("Error: TrainNetworks() Number of %s doesn't match in line %d\n"
                          % (kind, lineNumber))
        vectors.append(data)

    if not inputs:
        ExitWithError("Error: TrainNetworks() No training vectors found in %s\n"
                      % fileName)
    if len(inputs) != len(targets):
        ExitWithError("Error: TrainNetworks() %d inputs but %d outputs in %s\n"
                      % (len(inputs), len(targets), fileName))
    return inputs, targets


def ParseRunData(runBuffer, fileName):
    """Return (inputVectors, legendLabels) parsed from RunANN lines.

    Every non-blank line must start with 'I'.  Its label is the header's
    wind field (from the first 'w' to the next ':') and 'SWH' field joined by
    a space.  Exits the script on an invalid line, inconsistent vector
    lengths, or when no input vectors are found.
    """
    inputs = []
    legends = []
    for lineNumber, bufferLine in enumerate(runBuffer, 1):
        if not bufferLine.strip():
            continue
        if not bufferLine.startswith('I'):
            ExitWithError("Error: RunNetworks() Failed to find valid line in %s\n"
                          % fileName)
        data = ParseDataLine(bufferLine)
        if inputs and len(data) != len(inputs[0]):
            ExitWithError("Error: RunNetworks() Number of inputs doesn't match in line %d\n"
                          % lineNumber)
        inputs.append(data)
        legends.append(HeaderField(bufferLine, 'w') + " " +
                       HeaderField(bufferLine, 'SWH'))

    if not inputs:
        ExitWithError("Error: RunNetworks() No input vectors found in %s\n" % fileName)
    return inputs, legends


##----------------------------------------------------------------------------
## ANN helpers
def NetworkOutputs(ann, inputs):
    """Return ann([x]) for every vector x in *inputs*, in order."""
    return [ann([x]) for x in inputs]


def TrainWith(ann, label, trainMethod, testName, netFile, inputs, targets,
              trainArgs, randomize=True):
    """Train *ann* with one ffnet algorithm, save it, and return its answers.

    ann             -- the ffnet network being trained (modified in place)
    label           -- name echoed in the progress message, e.g. "train_cg"
    trainMethod     -- bound ffnet training method of *ann*, e.g. ann.train_cg
    testName        -- algorithm name printed when TestNetwork is set
    netFile         -- file the trained network is written to with savenet()
    inputs, targets -- training input and target vectors
    trainArgs       -- dict of keyword arguments passed to *trainMethod*
    randomize       -- call ann.randomweights() before training

    Returns NetworkOutputs(ann, inputs) for the trained network.
    """
    if randomize:
        ann.randomweights()

    sys.stdout.write("-> %s " % label)
    sys.stdout.flush()      # show progress before a possibly long training run
    startTime = clock()
    trainMethod(inputs, targets, **trainArgs)
    print("<- time %s" % (clock() - startTime))

    savenet(ann, netFile)
    outputs = NetworkOutputs(ann, inputs)

    if TestNetwork:
        print("")
        print("Testing %s..." % testName)
        ann.test(inputs, targets, iprint=testPrintLevel)
    return outputs


##----------------------------------------------------------------------------
##
def TrainNetworks():
    """Train one ANN per selected algorithm on the surf_is.py training data.

    Reads surf_is_TrainANN.dat, builds a multilayer network with layer sizes
    (inputs, inputs + 10, outputs + 5, outputs), full connectivity and bias
    nodes, then for each enabled algorithm trains it, saves it to
    ann_<algorithm>.net and stores its answers to the training inputs in the
    matching ann_out_* global.  Exits the script on malformed training data.
    """
    global ann_out_momentum, ann_out_genetic, ann_out_cg
    global ann_out_bfgs, ann_out_tnc
    global inputVectors, lastNetwork

    #--------------------------------------------------------------------
    # Read in the training data from the surf_is.py TrainANN output file
    TrainANNFile = "surf_is_TrainANN.dat"
    trainBuffer = ReadDataLines(TrainANNFile)
    print("TrainNetworks() Read %d lines from %s" % (len(trainBuffer), TrainANNFile))

    # Rebind (rather than append to) inputVectors so repeated calls don't
    # accumulate stale vectors.
    inputVectors, targetVectors = ParseTrainingData(trainBuffer, TrainANNFile)
    numInputs = len(inputVectors[0])
    numOutputs = len(targetVectors[0])
    print("TrainNetworks() numInputs= %d" % numInputs)
    print("TrainNetworks() numOutputs= %d" % numOutputs)

    #--------------------------------------------------------------------
    # Create multilayer network with full connectivity and bias nodes
    conec = mlgraph((numInputs, numInputs + 10, numOutputs + 5, numOutputs),
                    biases=True)
    ann = ffnet(conec)
    lastNetwork = ann

    #--------------------------------------------------------------------
    # Train networks.  Backpropagation uses the network as created by
    # ffnet(); every other algorithm re-randomizes the weights first so it
    # does not start from weights left by a previously run algorithm.
    if BackPropagation:
        ## Simple backpropagation training with momentum.
        ann_out_momentum = TrainWith(
            ann, "train_momentum", ann.train_momentum, "BackPropagation",
            "ann_momentum.net", inputVectors, targetVectors,
            {"maxiter": 20000, "disp": 0}, randomize=False)

    if GeneticAlgorithm:
        ## Global weights optimization with genetic algorithm.
        ## For more info see pikaia homepage and documentation:
        ## http://www.hao.ucar.edu/Public/models/pikaia/pikaia.html
        ann_out_genetic = TrainWith(
            ann, "train_genetic", ann.train_genetic, "GeneticAlgorithm",
            "ann_genetic.net", inputVectors, targetVectors,
            {"individuals": 20, "generations": 500})

    if ConjugateGradient:
        ## Train network with the nonlinear conjugate gradient algorithm of
        ## Polak and Ribiere.
        ## See Wright, and Nocedal 'Numerical Optimization', 1999, pg. 120-122.
        ann_out_cg = TrainWith(
            ann, "train_cg", ann.train_cg, "ConjugateGradient",
            "ann_cg.net", inputVectors, targetVectors,
            {"maxiter": 20000, "disp": 0})

    if BFGS:
        ## Train network with constrained version of quasi-Newton method
        ## of Broyden, Fletcher, Goldfarb, and Shanno (BFGS).
        ann_out_bfgs = TrainWith(
            ann, "train_bfgs", ann.train_bfgs, "BFGS",
            "ann_bfgs.net", inputVectors, targetVectors,
            {"factr": 1e7, "pgtol": 1e-5, "maxfun": 15000, "iprint": -1})

    if TNC:
        ## Train network with a TNC algorithm.
        ## TNC is a C implementation of TNBC, a truncated newton
        ## optimization package originally developed by Stephen G. Nash in Fortran.
        ann_out_tnc = TrainWith(
            ann, "train_tnc", ann.train_tnc, "TNC",
            "ann_tnc.net", inputVectors, targetVectors,
            {"maxfun": 5000, "messages": 0})


##----------------------------------------------------------------------------
##
def RunNetworks():
    """Apply the saved networks to the surf_is.py RunANN input vectors.

    Reads surf_is_RunANN.dat, replaces inputVectors and plotLegend with its
    contents, and for each enabled algorithm loads ann_<algorithm>.net and
    stores its answers in the matching ann_out_* global.  Exits the script
    on malformed input data.
    """
    global ann_out_momentum, ann_out_genetic, ann_out_cg
    global ann_out_bfgs, ann_out_tnc
    global inputVectors, plotLegend, lastNetwork

    #--------------------------------------------------------------------
    # Read in the input data from the surf_is.py RunANN output file
    RunANNFile = "surf_is_RunANN.dat"
    runBuffer = ReadDataLines(RunANNFile)
    print("RunNetworks() Read %d lines from %s" % (len(runBuffer), RunANNFile))

    # Replace, rather than extend, any vectors left by TrainNetworks() so the
    # outputs and the legend line up with the run inputs.
    inputVectors, plotLegend = ParseRunData(runBuffer, RunANNFile)
    print("RunNetworks() numInputs= %d" % len(inputVectors[0]))

    #--------------------------------------------------------------------
    # Apply the input data to the ANN's
    if BackPropagation:
        lastNetwork = loadnet("ann_momentum.net")
        ann_out_momentum = NetworkOutputs(lastNetwork, inputVectors)
    if GeneticAlgorithm:
        lastNetwork = loadnet("ann_genetic.net")
        ann_out_genetic = NetworkOutputs(lastNetwork, inputVectors)
    if ConjugateGradient:
        lastNetwork = loadnet("ann_cg.net")
        ann_out_cg = NetworkOutputs(lastNetwork, inputVectors)
    if BFGS:
        lastNetwork = loadnet("ann_bfgs.net")
        ann_out_bfgs = NetworkOutputs(lastNetwork, inputVectors)
    if TNC:
        lastNetwork = loadnet("ann_tnc.net")
        ann_out_tnc = NetworkOutputs(lastNetwork, inputVectors)


##----------------------------------------------------------------------------
##
def PlotResults():
    """Plot the ANN inputs and the selected networks' outputs, then show them.

    One subplot shows the input (acoustic) spectra, plus one per enabled
    BackPropagation / ConjugateGradient / BFGS / TNC network (GeneticAlgorithm
    outputs are not plotted).  Each output spectrum is drawn as a smoothing
    spline clipped at zero, and also raw when PlotData is set.  The figure is
    saved to is_surf_Train.ps or is_surf_Run.ps (is_surf.ps when neither -t
    nor -r was given) before show() is called once.
    """
    # the input (Infrasound Spectrum) abscissa has values from linspace(1., 20., 20)
    infrasoundX = linspace(1., 20., 20)
    # the output (Break Wave Spectrum) abscissa has values from linspace(0.035, 0.2, 20)
    breakHeightX = linspace(0.035, 0.2, 20)
    # finer abscissa on which the output splines are evaluated
    xs = linspace(0.035, 0.2, 150)

    # (outputs, title, alwaysXLabel) per output subplot, in plot order.
    # The BackPropagation subplot only gets an x label in a one-column layout.
    panels = []
    if BackPropagation:
        panels.append((ann_out_momentum,
                       'Breaking Wave Spectrum BackPropagation ANN', False))
    if ConjugateGradient:
        panels.append((ann_out_cg,
                       'Breaking Wave Spectrum ConjugateGradient ANN', True))
    if BFGS:
        panels.append((ann_out_bfgs, 'Hb BFGS ANN', True))
    if TNC:
        panels.append((ann_out_tnc, 'Hb TNC ANN', True))

    # subplot( nRows, nColumns, nFigure ): one row per subplot, folded into
    # two columns when there are 4 or 5 subplots.
    nRows = 1 + len(panels)
    nCols = 1
    if nRows == 4:
        nRows, nCols = 2, 2
    elif nRows == 5:
        nRows, nCols = 3, 2

    nFigure = 1
    subplot(nRows, nCols, nFigure)
    for inputVector in inputVectors:
        plot(infrasoundX, inputVector, linewidth=2)
    title('Acoustic Spectrum ANN Inputs')
    ylabel("dB")

    for outputs, panelTitle, alwaysXLabel in panels:
        nFigure += 1
        subplot(nRows, nCols, nFigure)
        for outputVector in outputs:
            # cubic spline of output spectrum, clipped at zero
            spline = UnivariateSpline(breakHeightX, outputVector, s=0.01)
            plot(xs, maximum(0., spline(xs)), linewidth=2)
            if PlotData:
                # raw output spectrum
                plot(breakHeightX, outputVector, linewidth=2)
        title(panelTitle)
        if alwaysXLabel or nCols == 1:
            xlabel("freq (Hz)")
        ylabel("ft")
        legend(plotLegend)
        xticks([0.02, 0.08, 0.14, 0.2])

    plotName = "is_surf"
    if TrainNetwork:
        plotName = "is_surf_Train"
    if RunNetwork:
        plotName = "is_surf_Run"
    # savefig will output to these formats:
    #   eps, jpeg, pdf, png, ps, svg
    # the format is deduced by the filename extension (.png is default)
    savefig(plotName + '.ps')

    if PlotNetwork:
        if lastNetwork is None:
            print("PlotResults(): no trained or loaded network to draw.")
        else:
            try:
                import networkx
            except ImportError:
                print("Failed to import networkx.")
            else:
                nFigure += 1
                figure(nFigure)
                networkx.draw_circular(lastNetwork.graph)

    # can only call show once!
    show()


##----------------------------------------------------------------------------
def usage():
    """Print the command-line options."""
    print("Usage:  %s" % sys.argv[0])
    print("\t -h (help)")
    print("\t -d (debug)")
    print("\t -t (train network)")
    print("\t -r (run network)")
    print("\t -x (test network)")
    print("\t -p (plot)")
    print("\t -b (BackPropagation default)")
    print("\t -g (GeneticAlgorithm)")
    print("\t -c (ConjugateGradient)")
    print("\t -f (BFGS)")
    print("\t -n (TNC)")


##----------------------------------------------------------------------------
def ParseCmdLine():
    """Set the global option flags from sys.argv.

    Prints usage() when no options are given.  Exits with status 1 after
    printing usage for -h, and with status 2 on an unrecognised option.
    """
    global DEBUG, TrainNetwork, RunNetwork, TestNetwork, Plot
    global BackPropagation, GeneticAlgorithm, ConjugateGradient, BFGS, TNC

    try:
        opts, _args = getopt.getopt(sys.argv[1:], "hdtrxpbgcfn")
    except getopt.GetoptError as error:
        # Continuing with the defaults would silently do nothing useful.
        print("ParseCmdLine(): GetoptError: %s" % error)
        usage()
        sys.exit(2)

    if not opts:
        usage()
        return

    for option, _value in opts:
        if option == "-h":
            usage()
            sys.exit(1)
        elif option == "-d":
            DEBUG = True
        elif option == "-t":
            TrainNetwork = True
        elif option == "-r":
            RunNetwork = True
        elif option == "-x":
            TestNetwork = True
        elif option == "-p":
            Plot = True
        elif option == "-b":
            BackPropagation = True
        elif option == "-g":
            GeneticAlgorithm = True
        elif option == "-c":
            ConjugateGradient = True
        elif option == "-f":
            BFGS = True
        elif option == "-n":
            TNC = True


##----------------------------------------------------------------------------
def OpenFile(file, mode):
    """Open *file* with *mode* and return the file object.

    On an I/O or OS error, prints the error and exits the script via
    sys.exit("OpenFile()").  Any other exception propagates unchanged.
    (The parameter name shadows the Python 2 builtin but is kept for callers.)
    """
    print("Opening file: %s" % file)
    try:
        return open(file, mode)
    except (IOError, OSError) as error:
        print("OpenFile(): I/O error: %s" % error)
        sys.exit("OpenFile()")


##----------------------------------------------------------------------------
## Provide for cmd line invocation
if __name__ == "__main__":
    main()