package org.openslx.taskmanager.tasks; import java.net.InetSocketAddress; import java.net.Socket; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Date; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.openslx.taskmanager.api.AbstractTask; import org.openslx.util.PrioThreadFactory; import com.google.gson.annotations.Expose; import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSchException; import com.jcraft.jsch.Session; public class RemoteReboot extends AbstractTask { static enum Mode { SHUTDOWN, REBOOT, KEXEC_REBOOT, } @Expose private Client[] clients; @Expose private Mode mode; @Expose private int minutes; @Expose private String locationId; @Expose private String sshkey; @Expose private int port; private JSch sshClient; private Output status = new Output(); private static final String BASE_CMD = "/opt/openslx/scripts/idleaction-scheduled_action"; private static final String REBOOT_CMD = BASE_CMD + " reboot"; private static final String KEXEC_CMD = BASE_CMD + " kexec-reboot"; private static final String SHUTDOWN_CMD = BASE_CMD + " poweroff"; @Override protected boolean initTask() { this.setStatusObject( this.status ); if ( minutes < 0 ) { status.addError( "Delay cannot be negative" ); } if ( sshkey == null || sshkey.length() == 0 ) { status.addError( "No SSH key given" ); } if ( port < 1 || port > 65535 ) { status.addError( "Invalid port number" ); } if (mode == null) { status.addError( "Invalid/no mode of operation" ); } if ( status.error != null ) return false; Date shutdownTime = new Date( System.currentTimeMillis() + minutes * 60 * 1000 ); SimpleDateFormat sdf = new SimpleDateFormat( "HH:mm" ); status.time = sdf.format( shutdownTime ); status.locationId = locationId; status.mode = mode; JSch.setConfig( "StrictHostKeyChecking", "no" ); sshClient = new JSch(); try { sshClient.addIdentity( "", sshkey.getBytes(), null, null ); } catch ( JSchException e ) { status.addError( e.toString() ); return false; } return true; } @Override protected boolean execute() { if ( clients.length == 0 ) return true; final List rebootingClients = new ArrayList<>(); // try to connect to every client and start the reboot/shutdown process ExecutorService tp = Executors.newFixedThreadPool( clients.length > 4 ? 4 : clients.length, new PrioThreadFactory( "RemRebt" ) ); for ( final Client client : clients ) { if ( client == null || client.clientip == null || client.machineuuid == null ) { status.addError( "null Client or missing ip/uuid in list, ignoring." ); continue; } tp.submit( new Runnable() { public void run() { int ret = -123; Session session = null; ChannelExec channel = null; try { status.clientStatus.put( client.machineuuid, ClientStatus.CONNECTING ); session = sshClient.getSession( "root", client.clientip, port ); session.connect( 5000 ); channel = (ChannelExec)session.openChannel( "exec" ); String args = " " + minutes + " " + String.format( "'%s'", client.machineuuid.replace( "'", "'\\''" ) ); if ( mode == Mode.SHUTDOWN ) { ret = execCommand( channel, SHUTDOWN_CMD + args ); status.clientStatus.put( client.machineuuid, minutes == 0 ? ClientStatus.SHUTDOWN : ClientStatus.SHUTDOWN_AT ); } else { if ( mode == Mode.REBOOT ) { ret = execCommand( channel, REBOOT_CMD + args ); } else { ret = execCommand( channel, KEXEC_CMD + args ); } if ( ret == 0 ) { status.clientStatus.put( client.machineuuid, minutes == 0 ? ClientStatus.REBOOTING : ClientStatus.REBOOT_AT ); rebootingClients.add( client ); } } } catch ( JSchException e ) { if ( "Auth fail".equals( e.getMessage() ) || "Auth cancel".equals( e.getMessage() ) ) { status.clientStatus.put( client.machineuuid, ClientStatus.AUTH_FAIL ); ret = 0; } else { status.addError( client.clientip + ": " + e.toString() ); ret = -123; } } finally { if ( session != null ) { try { channel.disconnect(); } catch ( Exception e ) { } try { session.disconnect(); } catch ( Exception e ) { } } } if ( ret != 0 ) { if ( ret != -123 ) { status.addError( client.clientip + ": Exit Code " + ret ); } status.clientStatus.put( client.machineuuid, ClientStatus.ERROR ); } } } ); } tp.shutdown(); try { tp.awaitTermination( clients.length * 5 + 10, TimeUnit.SECONDS ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); return false; } if ( minutes == 0 && rebootingClients.size() > 0 ) { // Give about 3 minutes for reboot, should be plenty // Determine online state if either ssh or winrpc/smb is open final int[] ports; if ( this.port == 22 ) { ports = new int[] { 22, 445 }; } else { ports = new int[] { this.port, 22, 445 }; } // Assume the boot loop takes at least 30 secs, don't even try before that try { Thread.sleep( 30000 ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); return false; } long lastcheck = 0; long deadline = System.currentTimeMillis() + 120 * 1000; while ( rebootingClients.size() > 0 && System.currentTimeMillis() < deadline ) { long delay = 10000 - ( System.currentTimeMillis() - lastcheck ); if ( delay > 0 ) { try { Thread.sleep( delay ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); return false; } } lastcheck = System.currentTimeMillis(); Iterator it = rebootingClients.iterator(); while ( it.hasNext() ) { Client client = it.next(); if ( isOnline( client.clientip, ports ) ) { it.remove(); status.clientStatus.put( client.machineuuid, ClientStatus.ONLINE ); } } } } // change status of clients that got stuck because of timeouts for ( Map.Entry entry : status.clientStatus.entrySet() ) { ClientStatus value = entry.getValue(); if ( value == ClientStatus.CONNECTING || ( minutes == 0 && value == ClientStatus.REBOOTING ) ) { entry.setValue( ClientStatus.ERROR ); } } return true; } private int execCommand( ChannelExec channel, String cmd ) throws JSchException { channel.setCommand( cmd ); channel.connect( 2000 ); for ( int i = 0; i < 2000 / 100; i++ ) { // Wait 2 seconds if ( channel.isClosed() ) { break; } try { Thread.sleep( 100 ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); break; } } return channel.getExitStatus(); } private boolean isOnline( String address, int... ports ) { for ( int port : ports ) { try ( Socket s = new Socket() ) { s.connect( new InetSocketAddress( address, port ), 1000 ); return true; } catch ( Exception ex ) { } } return false; } /** * Output - contains additional status data of this task */ @SuppressWarnings( "unused" ) static class Output { private Map clientStatus = new ConcurrentHashMap<>(); private String time; private String locationId; private String error; private Mode mode; private synchronized void addError( String e ) { if ( error == null ) { error = e + "\n"; } else { error += e + "\n"; } } } static class Client { @Expose private String machineuuid; @Expose private String clientip; } static enum ClientStatus { CONNECTING, REBOOTING, REBOOT_AT, SHUTDOWN, SHUTDOWN_AT, ONLINE, ERROR, AUTH_FAIL; } }