package org.openslx.taskmanager.tasks; import java.net.InetSocketAddress; import java.net.Socket; import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Date; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.openslx.taskmanager.api.AbstractTask; import com.google.gson.annotations.Expose; import com.jcraft.jsch.ChannelExec; import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSchException; import com.jcraft.jsch.Session; public class RemoteReboot extends AbstractTask { @Expose private Client[] clients; @Expose private boolean shutdown; @Expose private int minutes; @Expose private String locationId; @Expose private String locationName; @Expose private String sshkey; @Expose private int port; private JSch sshClient; private Output status = new Output(); private static final String REBOOT_CMD = "/opt/openslx/scripts/idleaction-scheduled_action --detach reboot"; private static final String SHUTDOWN_CMD = "/opt/openslx/scripts/idleaction-scheduled_action --detach poweroff"; @Override protected boolean initTask() { this.setStatusObject( this.status ); if ( minutes < 0 ) { status.addError( "Delay cannot be negative" ); } if ( sshkey == null || sshkey.length() == 0 ) { status.addError( "No SSH key given" ); } if ( port < 1 || port > 65535 ) { status.addError( "Invalid port number" ); } if ( status.error != null ) return false; status.clients = clients; Date shutdownTime = new Date( System.currentTimeMillis() + minutes * 60 * 1000 ); SimpleDateFormat sdf = new SimpleDateFormat( "HH:mm" ); status.time = sdf.format( shutdownTime ); status.locationId = locationId; status.locationName = locationName; JSch.setConfig("StrictHostKeyChecking", "no"); sshClient = new JSch(); try { sshClient.addIdentity("", sshkey.getBytes(), null, null); } catch ( JSchException e ) { status.addError( e.toString() ); return false; } return true; } @Override protected boolean execute() { if ( clients.length == 0 ) return true; final List rebootingClients = new ArrayList<>(); // try to connect to every client and start the reboot/shutdown process ExecutorService tp = Executors.newFixedThreadPool( clients.length > 4 ? 4 : clients.length ); for ( final Client client : clients ) { if ( client == null || client.clientip == null || client.machineuuid == null ) { status.addError( "null Client or missing ip/uuid in list, ignoring." ); continue; } tp.submit( new Runnable() { public void run() { int ret = -1; try { status.clientStatus.put( client.machineuuid, ClientStatus.CONNECTING ); Session session = sshClient.getSession("root", client.clientip, port); session.connect(5000); ChannelExec channel = (ChannelExec) session.openChannel("exec"); String args = " " + minutes + " " + String.format("'%s'", client.machineuuid.replace("'", "'\\''")); if ( shutdown ) { channel.setCommand( SHUTDOWN_CMD + args ); channel.connect(2000); waitForCommand(channel, 2000); ret = channel.getExitStatus(); status.clientStatus.put( client.machineuuid, minutes == 0 ? ClientStatus.SHUTDOWN : ClientStatus.SHUTDOWN_AT ); } else { channel.setCommand( REBOOT_CMD + args ); channel.connect(2000); waitForCommand(channel, 2000); ret = channel.getExitStatus(); if ( ret == 0 ) { status.clientStatus.put( client.machineuuid, minutes == 0 ? ClientStatus.REBOOTING : ClientStatus.REBOOT_AT ); rebootingClients.add( client ); } } channel.disconnect(); session.disconnect(); } catch ( JSchException e ) { if ( e.toString().contains( "Auth fail" ) ) { status.clientStatus.put( client.machineuuid, ClientStatus.AUTH_FAIL ); ret = 0; } else { status.addError( client.clientip + ": " + e.toString() ); ret = -1; } } if ( ret != 0 ) { if ( ret != -1 ) { status.addError( client.clientip + ": Exit Code " + ret ); } status.clientStatus.put( client.machineuuid, ClientStatus.ERROR ); } } } ); } tp.shutdown(); try { tp.awaitTermination( clients.length * 5, TimeUnit.SECONDS ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); return false; } if ( minutes == 0 && rebootingClients.size() > 0 ) { // Give about 3 minutes for reboot, should be plenty // Determine online state if either ssh or winrpc/smb is open final int[] ports; if ( this.port == 22 ) { ports = new int[] { 22, 445 }; } else { ports = new int[] { this.port, 22, 445 }; } // Assume the boot loop takes at least 30 secs, don't even try before that try { Thread.sleep( 30000 ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); return false; } long lastcheck = 0; long deadline = System.currentTimeMillis() + 120 * 1000; while ( rebootingClients.size() > 0 && System.currentTimeMillis() < deadline ) { long delay = 10000 - ( System.currentTimeMillis() - lastcheck ); if ( delay > 0 ) { try { Thread.sleep( delay ); } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); return false; } } lastcheck = System.currentTimeMillis(); Iterator it = rebootingClients.iterator(); while ( it.hasNext() ) { Client client = it.next(); if ( isOnline( client.clientip, ports ) ) { it.remove(); status.clientStatus.put( client.machineuuid, ClientStatus.ONLINE ); } } } } // change status of clients that got stuck because of timeouts for ( Map.Entry entry : status.clientStatus.entrySet() ) { ClientStatus value = entry.getValue(); if ( value == ClientStatus.CONNECTING || ( minutes == 0 && value == ClientStatus.REBOOTING ) ) { entry.setValue( ClientStatus.ERROR ); } } return true; } private void waitForCommand(ChannelExec channel, int timeout) { for ( int i = 0; i < timeout / 100; i++ ) { if ( channel.isClosed() ) { break; } try { Thread.sleep(100); } catch (InterruptedException e) {} } } private boolean isOnline( String address, int... ports ) { for ( int port : ports ) { try ( Socket s = new Socket() ) { s.connect( new InetSocketAddress( address, port ), 1000 ); return true; } catch ( Exception ex ) { } } return false; } /** * Output - contains additional status data of this task */ @SuppressWarnings( "unused" ) static class Output { private Map clientStatus = new ConcurrentHashMap<>(); private Client[] clients; private String time; private String locationId; private String locationName; private String error; private synchronized void addError( String e ) { if ( error == null ) { error = e + "\n"; } else { error += e + "\n"; } } } static class Client { @Expose private String machineuuid; @Expose private String clientip; } static enum ClientStatus { CONNECTING, REBOOTING, REBOOT_AT, SHUTDOWN, SHUTDOWN_AT, ONLINE, ERROR, AUTH_FAIL; } }