diff options
Diffstat (limited to 'src/main/java/org/openslx/util/Util.java')
-rw-r--r-- | src/main/java/org/openslx/util/Util.java | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/src/main/java/org/openslx/util/Util.java b/src/main/java/org/openslx/util/Util.java index e425083..b8385f8 100644 --- a/src/main/java/org/openslx/util/Util.java +++ b/src/main/java/org/openslx/util/Util.java @@ -1,7 +1,18 @@ package org.openslx.util; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.UnknownHostException; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.regex.Pattern; +import javax.net.SocketFactory; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -165,4 +176,88 @@ public class Util return String.format( "%.1f %s", val, UNITS[unit] ); } + /** + * Connect to given host(name), trying all addresses it resolves to. + */ + public static Socket connectAllRecords( SocketFactory fac, String host, int port, int timeout ) throws IOException + { + InetAddress[] addrList; + addrList = InetAddress.getAllByName( host ); + if ( addrList.length == 0 ) { + throw new UnknownHostException( "Unknown host: " + host ); + } else if ( addrList.length == 1 ) { + // Simple case + Socket s = fac.createSocket(); + s.connect( new InetSocketAddress( addrList[0], port ), timeout ); + return s; + } + // Cascaded connects + log.debug( "Got " + addrList.length + " hosts for " + host ); + String name = host.length() > 12 ? host.substring( 0, 12 ) : host; + ThreadPoolExecutor tpe = new CascadedThreadPoolExecutor( Math.min( addrList.length, 4 ), 4, + 2, TimeUnit.SECONDS, + 4, new ThreadPoolExecutor.AbortPolicy(), name ); + final AtomicReference<IOException> fe = new AtomicReference<>(); + final AtomicReference<Socket> retSock = new AtomicReference<>(); + final Semaphore sem = new Semaphore( 0 ); + try { + int endIdx = addrList.length - 1; + for ( int idx = 0; idx <= endIdx; idx++ ) { + InetAddress addr = addrList[idx]; + // Create next connect task + Runnable task = new Runnable() { + @Override + public void run() + { + try { + Socket s = fac.createSocket(); + log.debug( "Trying " + addr.toString() ); + s.connect( new InetSocketAddress( addr, port ), timeout ); + if ( retSock.compareAndSet( null, s ) ) { + log.debug( addr.toString() + ": Success" ); + sem.release(); + } else { + // Lost race with another thread + log.debug( addr.toString() + ": Success, but lost race" ); + s.close(); + } + } catch ( IOException e ) { + fe.set( e ); + } + } + }; + try { + tpe.execute( task ); + } catch ( Exception e ) { + log.debug( "Failed to queue connect for " + addr.toString() ); + } + // Wait for semaphore, or 250ms timeout + boolean gotit = false; + log.debug( "Waiting for connect..." ); + try { + // Wait longer on last iteration + if ( idx == endIdx ) { + gotit = sem.tryAcquire( timeout + 10, TimeUnit.MILLISECONDS ); + } else { + gotit = sem.tryAcquire( 250, TimeUnit.MILLISECONDS ); + } + } catch ( InterruptedException e ) { + } + if ( gotit ) { + // Got semaphore, a task should have succeeded + sem.release(); + return retSock.getAndSet( null ); + } + } + log.debug( "Out of addresses" ); + sem.release(); + throw fe.get(); + } finally { + tpe.shutdownNow(); + // Make sure a connect that succeeded after we exit the loop + // but before we call sem.release() will be cleaned up + Util.safeClose( retSock.get() ); + } + } + } |