summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSimon Rettberg2024-06-27 14:36:10 +0200
committerSimon Rettberg2024-06-27 14:36:10 +0200
commit6f3cd3f67ba4a4508b0a7b4bf944ba9574f98445 (patch)
treea5abb47cc033bb48605694e21cf4e6124e5a3f23
parent[libvirt] Add getter/setter for os firmware (diff)
downloadmaster-sync-shared-6f3cd3f67ba4a4508b0a7b4bf944ba9574f98445.tar.gz
master-sync-shared-6f3cd3f67ba4a4508b0a7b4bf944ba9574f98445.tar.xz
master-sync-shared-6f3cd3f67ba4a4508b0a7b4bf944ba9574f98445.zip
[Util] Add socket connect helper to use all available A/AAAA records
(and use it)
-rw-r--r--src/main/java/org/openslx/filetransfer/Transfer.java17
-rw-r--r--src/main/java/org/openslx/thrifthelper/ThriftManager.java10
-rw-r--r--src/main/java/org/openslx/util/Util.java95
3 files changed, 104 insertions, 18 deletions
diff --git a/src/main/java/org/openslx/filetransfer/Transfer.java b/src/main/java/org/openslx/filetransfer/Transfer.java
index aebd3ce..c987576 100644
--- a/src/main/java/org/openslx/filetransfer/Transfer.java
+++ b/src/main/java/org/openslx/filetransfer/Transfer.java
@@ -4,21 +4,20 @@ import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
-import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
+import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
-import javax.net.ssl.SSLSocketFactory;
-
-import net.jpountz.lz4.LZ4Factory;
import org.apache.logging.log4j.Logger;
import org.openslx.util.Util;
+import net.jpountz.lz4.LZ4Factory;
+
public abstract class Transfer
{
protected final Socket transferSocket;
@@ -45,14 +44,10 @@ public abstract class Transfer
{
this.log = log;
// create socket.
- if ( context == null ) {
- transferSocket = new Socket();
- } else {
- SSLSocketFactory sslSocketFactory = context.getSocketFactory();
- transferSocket = sslSocketFactory.createSocket();
- }
+ transferSocket = Util.connectAllRecords(
+ context == null ? SocketFactory.getDefault() : context.getSocketFactory(),
+ host, port, 4000 );
transferSocket.setSoTimeout( readTimeoutMs );
- transferSocket.connect( new InetSocketAddress( host, port ), 4000 );
outStream = new DataOutputStream( transferSocket.getOutputStream() );
dataFromServer = new DataInputStream( transferSocket.getInputStream() );
diff --git a/src/main/java/org/openslx/thrifthelper/ThriftManager.java b/src/main/java/org/openslx/thrifthelper/ThriftManager.java
index 07256b2..9bed5cd 100644
--- a/src/main/java/org/openslx/thrifthelper/ThriftManager.java
+++ b/src/main/java/org/openslx/thrifthelper/ThriftManager.java
@@ -2,7 +2,6 @@ package org.openslx.thrifthelper;
import java.io.IOException;
import java.lang.reflect.Proxy;
-import java.net.InetSocketAddress;
import java.net.Socket;
import javax.net.SocketFactory;
@@ -206,12 +205,9 @@ public class ThriftManager<T>
TSocket tsock;
Socket socket = null;
try {
- if ( ctx == null ) {
- socket = SocketFactory.getDefault().createSocket();
- } else {
- socket = ctx.getSocketFactory().createSocket();
- }
- socket.connect( new InetSocketAddress( host, port ), 4000 );
+ socket = Util.connectAllRecords(
+ ctx == null ? SocketFactory.getDefault() : ctx.getSocketFactory(),
+ host, port, 4000 );
socket.setSoTimeout( timeout );
} catch ( IOException e ) {
if ( socket != null ) {
diff --git a/src/main/java/org/openslx/util/Util.java b/src/main/java/org/openslx/util/Util.java
index e425083..b8385f8 100644
--- a/src/main/java/org/openslx/util/Util.java
+++ b/src/main/java/org/openslx/util/Util.java
@@ -1,7 +1,18 @@
package org.openslx.util;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.net.UnknownHostException;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Pattern;
+import javax.net.SocketFactory;
+
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -165,4 +176,88 @@ public class Util
return String.format( "%.1f %s", val, UNITS[unit] );
}
+ /**
+ * Connect to given host(name), trying all addresses it resolves to.
+ */
+ public static Socket connectAllRecords( SocketFactory fac, String host, int port, int timeout ) throws IOException
+ {
+ InetAddress[] addrList;
+ addrList = InetAddress.getAllByName( host );
+ if ( addrList.length == 0 ) {
+ throw new UnknownHostException( "Unknown host: " + host );
+ } else if ( addrList.length == 1 ) {
+ // Simple case
+ Socket s = fac.createSocket();
+ s.connect( new InetSocketAddress( addrList[0], port ), timeout );
+ return s;
+ }
+ // Cascaded connects
+ log.debug( "Got " + addrList.length + " hosts for " + host );
+ String name = host.length() > 12 ? host.substring( 0, 12 ) : host;
+ ThreadPoolExecutor tpe = new CascadedThreadPoolExecutor( Math.min( addrList.length, 4 ), 4,
+ 2, TimeUnit.SECONDS,
+ 4, new ThreadPoolExecutor.AbortPolicy(), name );
+ final AtomicReference<IOException> fe = new AtomicReference<>();
+ final AtomicReference<Socket> retSock = new AtomicReference<>();
+ final Semaphore sem = new Semaphore( 0 );
+ try {
+ int endIdx = addrList.length - 1;
+ for ( int idx = 0; idx <= endIdx; idx++ ) {
+ InetAddress addr = addrList[idx];
+ // Create next connect task
+ Runnable task = new Runnable() {
+ @Override
+ public void run()
+ {
+ try {
+ Socket s = fac.createSocket();
+ log.debug( "Trying " + addr.toString() );
+ s.connect( new InetSocketAddress( addr, port ), timeout );
+ if ( retSock.compareAndSet( null, s ) ) {
+ log.debug( addr.toString() + ": Success" );
+ sem.release();
+ } else {
+ // Lost race with another thread
+ log.debug( addr.toString() + ": Success, but lost race" );
+ s.close();
+ }
+ } catch ( IOException e ) {
+ fe.set( e );
+ }
+ }
+ };
+ try {
+ tpe.execute( task );
+ } catch ( Exception e ) {
+ log.debug( "Failed to queue connect for " + addr.toString() );
+ }
+ // Wait for semaphore, or 250ms timeout
+ boolean gotit = false;
+ log.debug( "Waiting for connect..." );
+ try {
+ // Wait longer on last iteration
+ if ( idx == endIdx ) {
+ gotit = sem.tryAcquire( timeout + 10, TimeUnit.MILLISECONDS );
+ } else {
+ gotit = sem.tryAcquire( 250, TimeUnit.MILLISECONDS );
+ }
+ } catch ( InterruptedException e ) {
+ }
+ if ( gotit ) {
+ // Got semaphore, a task should have succeeded
+ sem.release();
+ return retSock.getAndSet( null );
+ }
+ }
+ log.debug( "Out of addresses" );
+ sem.release();
+ throw fe.get();
+ } finally {
+ tpe.shutdownNow();
+ // Make sure a connect that succeeded after we exit the loop
+ // but before we call sem.release() will be cleaned up
+ Util.safeClose( retSock.get() );
+ }
+ }
+
}