##----------------------------------------------------------------------------
## Name: is_surf_ann.py -
## Purpose: Train and run a feed forward ANN as an inverse model
## for the surf_is.py infrasound surf generation model
##
## Author: J Park
##
## Created: 1/17/07
##
## Modified:
##----------------------------------------------------------------------------
import getopt
import sys
from time import clock
# ANN is based on ffnet-0.5 (http://ffnet.sourceforge.net)
# ffnet depends on:
# 1. numpy/scipy (http://scipy.org/)
# 2. NetworkX (https://networkx.lanl.gov/wiki)
# 3. matplotlib (http://matplotlib.sourceforge.net/)
try:
    from ffnet import ffnet, mlgraph, savenet, loadnet
except ImportError:
    print "Failed to import ffnet."
try:
    from pylab import *
except ImportError:
    print "Failed to import pylab for matplotlib."
from scipy.interpolate.fitpack2 import UnivariateSpline
#--------------------------------------------------
# Global Parameters
#
# Run-mode switches, set from the command line by ParseCmdLine();
# the trailing comment names the flag that turns each one on.
DEBUG = False               # -d
TrainNetwork = False        # -t
RunNetwork = False          # -r
TestNetwork = False         # -x
Plot = False                # -p

# iprint level passed to ann.test() when -x is given
testPrintLevel = 1

# Plot extras with no command-line flag
PlotData = False            # also plot the raw (un-splined) ANN outputs
PlotNetwork = False         # also draw the network graph

# Training-algorithm selection
BackPropagation = True      # -b
GeneticAlgorithm = False    # -g
ConjugateGradient = False   # -c
BFGS = False                # -f
TNC = False                 # -n

# Data shared between the train/run stages and PlotResults();
# each name is bound to its own distinct list.
inputVectors, plotLegend = [], []
ann_out_momentum, ann_out_genetic, ann_out_cg = [], [], []
ann_out_bfgs, ann_out_tnc = [], []
##----------------------------------------------------------------------------
## Main module
def main():
ParseCmdLine()
if TrainNetwork:
TrainNetworks()
if RunNetwork:
RunNetworks()
if Plot:
PlotResults()
print "Normal Exit"
##----------------------------------------------------------------------------
##
def TrainNetworks():
global ann_out_momentum, ann_out_genetic, ann_out_cg
global ann_out_bfgs, ann_out_tnc
global inputVectors
#--------------------------------------------------------------------
# Read in the training data from the surf_is.py TrainANN output file
TrainANNFile = "surf_is_TrainANN.dat"
fdTrainANN = OpenFile(TrainANNFile, 'r')
trainBuffer = fdTrainANN.readlines()
fdTrainANN.close()
print "TrainNetworks() Read %d lines from %s" % ( len(trainBuffer), TrainANNFile )
#--------------------------------------------------------------------
# Extract the training data into vectors for the ANN I/O
# Each line in the trainBuffer represents an Input or Output vector
# The vectors are paired in order, one Input followed by it's Output
# Input is the Infrasound Spectrum
# Target is the Breaking Wave Spectrum
targetVectors = []
numInputs = 0
numOutputs = 0
for i in range( len(trainBuffer) ):
bufferLine = trainBuffer[i]
inputVector = False
targetVector = False
if bufferLine.startswith('I') :
inputVector = True
elif bufferLine.startswith('O') :
targetVector = True
else:
print "Error: TrainNetworks() Failed to find valid line in %s\n" % TrainANNFile
sys.exit(1)
# : was used to delimit the header from the data
dataStartIndex = bufferLine.rfind(':') + 1
# take slice from the data start to end of the line
dataBuffer = bufferLine[ dataStartIndex : ]
# now decompose the line into words of data
dataWords = dataBuffer.split()
# convert to floating point
data = [ float(s) for s in dataWords ]
if inputVector:
inputVectors.append( data )
if not numInputs: numInputs = len( data )
elif numInputs != len( data ):
print "Error: TrainNetworks() Number of inputs doesn't match in line %d\n", i+1
sys.exit(1)
elif targetVector:
targetVectors.append( data )
if not numOutputs: numOutputs = len( data )
elif numOutputs != len( data ):
print "Error: TrainNetworks() Number of outputs doesn't match in line %d\n", i+1
sys.exit(1)
else:
print "Error: TrainNetworks() Invalid state of in/outputVector"
sys.exit(1)
print "TrainNetworks() numInputs=", numInputs
print "TrainNetworks() numOutputs=", numOutputs
#--------------------------------------------------------------------
# Create multilayer network with full connectivity and bias nodes
conec = mlgraph( (numInputs, numInputs + 10 , numOutputs + 5, numOutputs), biases = True )
# Create ANN
ann = ffnet(conec)
#--------------------------------------------------------------------
# Train networks
if BackPropagation:
## Simple backpropagation training with momentum.
print "-> train_momentum",
startTime = clock()
ann.train_momentum (inputVectors, targetVectors, maxiter = 20000, disp = 0)
momentum_dt = clock() - startTime
print "<- time ", momentum_dt
savenet(ann, "ann_momentum.net")
ann_out_momentum = [ ann([x]) for x in inputVectors ]
if TestNetwork:
print
print "Testing BackPropagation..."
output, regression = ann.test(inputVectors, targetVectors, iprint = testPrintLevel)
if GeneticAlgorithm:
## Global weights optimization with genetic algorithm.
## For more info see pikaia homepage and documentation:
## http://www.hao.ucar.edu/Public/models/pikaia/pikaia.html
ann.randomweights()
print "-> train_genetic ",
startTime = clock()
ann.train_genetic (inputVectors, targetVectors, individuals=20, generations=500)
genetic_dt = clock() - startTime
print "<- time ", genetic_dt
savenet(ann, "ann_genetic.net")
ann_out_genetic = [ ann([x]) for x in inputVectors ]
if TestNetwork:
print
print "Testing GeneticAlgorithm..."
output, regression = ann.test(inputVectors, targetVectors, iprint = testPrintLevel)
if ConjugateGradient:
## Train network with conjugate gradient algorithm using the
## nonlinear conjugate gradient algorithm of Polak and Ribiere.
## See Wright, and Nocedal 'Numerical Optimization', 1999, pg. 120-122.
kwargs = { "maxiter":20000 , "disp":0 }
ann.randomweights()
print "-> train_cg ",
startTime = clock()
ann.train_cg (inputVectors, targetVectors, **kwargs)
cg_dt = clock() - startTime
print "<- time ", cg_dt
savenet(ann, "ann_cg.net")
ann_out_cg = [ ann([x]) for x in inputVectors ]
if TestNetwork:
print
print "Testing ConjugateGradient..."
output, regression = ann.test(inputVectors, targetVectors, iprint = testPrintLevel)
if BFGS:
## Train network with constrained version of quasi-Newton method
## of Broyden, Fletcher, Goldfarb, and Shanno (BFGS).
kwargs = { "factr":1e7 , "pgtol":1e-5 , "maxfun":15000, "iprint":-1 }
ann.randomweights()
print "-> train_bfgs ",
startTime = clock()
ann.train_bfgs (inputVectors, targetVectors, **kwargs)
bfgs_dt = clock() - startTime
print "<- time ", bfgs_dt
savenet(ann, "ann_bfgs.net")
ann_out_bfgs = [ ann([x]) for x in inputVectors ]
if TestNetwork:
print
print "Testing BFGS..."
output, regression = ann.test(inputVectors, targetVectors, iprint = testPrintLevel)
if TNC:
## Train network with a TNC algorithm.
## TNC is a C implementation of TNBC, a truncated newton
## optimization package originally developed by Stephen G. Nash in Fortran.
print "-> train_tnc ",
startTime = clock()
ann.train_tnc (inputVectors, targetVectors, maxfun = 5000, messages=0)
tnc_dt = clock() - startTime
print "<- time ", tnc_dt
savenet(ann, "ann_tnc.net")
ann_out_tnc = [ ann([x]) for x in inputVectors ]
if TestNetwork:
print
print "Testing TNC..."
output, regression = ann.test(inputVectors, targetVectors, iprint = testPrintLevel)
##----------------------------------------------------------------------------
##
def RunNetworks():
global ann_out_momentum, ann_out_genetic, ann_out_cg
global ann_out_bfgs, ann_out_tnc
global inputVectors, plotLegend
#--------------------------------------------------------------------
# Read in the input data from the surf_is.py RunANN output file
RunANNFile = "surf_is_RunANN.dat"
fdRunANN = OpenFile(RunANNFile, 'r')
runBuffer = fdRunANN.readlines()
fdRunANN.close()
print "RunNetworks() Read %d lines from %s" % ( len(runBuffer), RunANNFile )
#--------------------------------------------------------------------
# Extract the data into vectors for ANN Input
numInputs = 0
for i in range( len(runBuffer) ):
bufferLine = runBuffer[i]
if not bufferLine.startswith('I') :
print "Error: RunNetworks() Failed to find valid line in %s\n" % RunANNFile
sys.exit(1)
# get the wind and SWH
windStartIndex = bufferLine.find('w')
windEndIndex = bufferLine.find(':', windStartIndex)
windSpeed = bufferLine[ windStartIndex : windEndIndex ]
SWHStartIndex = bufferLine.find('SWH')
SWHEndIndex = bufferLine.find(':', SWHStartIndex)
SWH = bufferLine[ SWHStartIndex : SWHEndIndex ]
plotLegend.append( windSpeed + " " + SWH )
# : was used to delimit the header from the data
dataStartIndex = bufferLine.rfind(':') + 1
# take slice from the data start to end of the line
dataBuffer = bufferLine[ dataStartIndex : ]
# now decompose the line into words of data
dataWords = dataBuffer.split()
# convert to floating point
data = [ float(s) for s in dataWords ]
inputVectors.append( data )
if not numInputs: numInputs = len( data )
elif numInputs != len( data ):
print "Error: RunNetworks() Number of inputs doesn't match in line %d\n", i+1
sys.exit(1)
print "RunNetworks() numInputs=", numInputs
#--------------------------------------------------------------------
# Apply the input data to the ANN's
if BackPropagation:
ann = loadnet( "ann_momentum.net" )
ann_out_momentum = [ ann([x]) for x in inputVectors ]
if GeneticAlgorithm:
ann = loadnet( "ann_genetic.net" )
ann_out_genetic = [ ann([x]) for x in inputVectors ]
if ConjugateGradient:
ann = loadnet( "ann_cg.net" )
ann_out_cg = [ ann([x]) for x in inputVectors ]
if BFGS:
ann = loadnet( "ann_bfgs.net" )
ann_out_bfgs = [ ann([x]) for x in inputVectors ]
if TNC:
ann = loadnet( "ann_tnc.net" )
ann_out_tnc = [ ann([x]) for x in inputVectors ]
##----------------------------------------------------------------------------
##
def PlotResults():
# the input (Infrasound Spectrum) abscissa has values from linspace(1., 20., 20)
infrasoundX = linspace(1., 20., 20)
# the output (Break Wave Spectrum) abscissa has values from linspace(0.035, 0.2, 20)
breakHeightX = linspace(0.035, 0.2, 20)
xs = linspace(0.035, 0.2, 150)
# subplot( nRows, nColumns, nFigure )
nRows = 1
nCols = 1
nFigure = 1
if BackPropagation:
nRows += 1
if ConjugateGradient:
nRows += 1
if BFGS:
nRows += 1
if TNC:
nRows += 1
if nRows == 4:
nCols = 2
nRows = 2
elif nRows == 5:
nCols = 2
nRows = 3
#figure(1)
subplot( nRows, nCols, nFigure )
for i in range( len( inputVectors ) ):
plot(infrasoundX, inputVectors[i], linewidth=2)
title('Acoustic Spectrum ANN Inputs')
ylabel( "dB" )
if BackPropagation:
#figure(2)
nFigure += 1
subplot( nRows, nCols, nFigure )
for i in range( len( ann_out_momentum ) ):
# cubic spline of output spectrum
spline = UnivariateSpline( breakHeightX, ann_out_momentum[i], s=0.01 )
ys = spline( xs )
plot( xs, maximum( 0., ys ), linewidth=2)
# raw output spectrum
if PlotData:
plot(breakHeightX, ann_out_momentum[i], linewidth=2)
title('Breaking Wave Spectrum BackPropagation ANN')
if nCols == 1: xlabel( "freq (Hz)" )
ylabel( "ft" )
legend( plotLegend )
xticks( [0.02, 0.08, 0.14, 0.2] )
if ConjugateGradient:
#figure(3)
nFigure += 1
subplot( nRows, nCols, nFigure )
for i in range( len( ann_out_cg ) ):
# cubic spline of output spectrum
spline = UnivariateSpline( breakHeightX, ann_out_cg[i], s=0.01 )
ys = spline( xs )
plot( xs, maximum( 0., ys ), linewidth=2)
# raw output spectrum
if PlotData:
plot(breakHeightX, ann_out_cg[i], linewidth=2)
title('Breaking Wave Spectrum ConjugateGradient ANN')
xlabel( "freq (Hz)" )
ylabel( "ft" )
legend( plotLegend )
xticks( [0.02, 0.08, 0.14, 0.2] )
if BFGS:
#figure(4)
nFigure += 1
subplot( nRows, nCols, nFigure )
for i in range( len( ann_out_bfgs ) ):
# cubic spline of output spectrum
spline = UnivariateSpline( breakHeightX, ann_out_bfgs[i], s=0.01 )
ys = spline( xs )
plot( xs, maximum( 0., ys ), linewidth=2)
# raw output spectrum
if PlotData:
plot(breakHeightX, ann_out_bfgs[i], linewidth=2)
title('Hb BFGS ANN')
xlabel( "freq (Hz)" )
ylabel( "ft" )
legend( plotLegend )
xticks( [0.02, 0.08, 0.14, 0.2] )
if TNC:
#figure(5)
nFigure += 1
subplot( nRows, nCols, nFigure )
for i in range( len( ann_out_tnc ) ):
# cubic spline of output spectrum
spline = UnivariateSpline( breakHeightX, ann_out_tnc[i], s=0.01 )
ys = spline( xs )
plot( xs, maximum( 0., ys ), linewidth=2)
# raw output spectrum
if PlotData:
plot(breakHeightX, ann_out_tnc[i], linewidth=2)
title('Hb TNC ANN')
xlabel( "freq (Hz)" )
ylabel( "ft" )
legend( plotLegend )
xticks( [0.02, 0.08, 0.14, 0.2] )
if TrainNetwork:
plotName = "is_surf_Train"
if RunNetwork:
plotName = "is_surf_Run"
# savefig will output to these formats:
# eps, jpeg, pdf, png, ps, svg
# the format is deduced by the filename extension (.png is default)
savefig( plotName + '.ps' )
if PlotNetwork:
try:
import networkx
nFigure += 1
figure( nFigure )
networkx.draw_circular(ann.graph)
# can only call show once!
show()
except ImportError:
print "Failed to import networkx."
# can only call show once!
show()
##----------------------------------------------------------------------------
def usage():
print "Usage: ", sys.argv[0]
print "\t -h (help)"
print "\t -d (debug)"
print "\t -t (train network)"
print "\t -r (run network)"
print "\t -x (test network)"
print "\t -p (plot)"
print "\t -b (BackPropagation default)"
print "\t -g (GeneticAlgorithm)"
print "\t -c (ConjugateGradient)"
print "\t -f (BFGS)"
print "\t -n (TNC)"
##----------------------------------------------------------------------------
def ParseCmdLine():
##
global DEBUG, TrainNetwork, RunNetwork, TestNetwork, Plot
global BackPropagation, GeneticAlgorithm, ConjugateGradient, BFGS, TNC
try:
opts, args = getopt.getopt(sys.argv[1:], "hdtrxpbgcfn")
except getopt.GetoptError, error:
print "ParseCmdLine(): GetoptError: ",error
return
if len(opts) == 0:
usage()
return
argCount = 0
for optElement in opts:
argCount = argCount + 1
if optElement[0] == "-h":
usage()
sys.exit(1)
if optElement[0] == "-d":
DEBUG = True
if optElement[0] == "-t":
TrainNetwork = True
if optElement[0] == "-r":
RunNetwork = True
if optElement[0] == "-x":
TestNetwork = True
if optElement[0] == "-p":
Plot = True
if optElement[0] == "-b":
BackPropagation = True
if optElement[0] == "-g":
GeneticAlgorithm = True
if optElement[0] == "-c":
ConjugateGradient = True
if optElement[0] == "-f":
BFGS = True
if optElement[0] == "-n":
TNC = True
return
##----------------------------------------------------------------------------
def OpenFile(file, mode):
try:
# open() will raise an IOError exception if it fails
print "Opening file:", file
fd = open(file,mode)
except IOError, error:
print "OpenFile(): I/O error:", error
sys.exit("OpenFile()")
except OSError, error:
print "OpenFile(): OS error:", error
sys.exit("OpenFile()")
except:
print "OpenFile(): Unexpected error:", sys.exc_info()[0]
raise
sys.exit("OpenFile()")
else: # executed after the try if no exceptions
return fd
##----------------------------------------------------------------------------
## Run main() only when executed as a script, not when imported as a module
if "__main__" == __name__:
    main()