package com.example.sshd.service; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; import java.util.Map; import java.util.Optional; import java.util.Properties; import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.digest.DigestUtils; import org.apache.commons.exec.CommandLine; import org.apache.commons.exec.DefaultExecutor; import org.apache.commons.exec.ExecuteException; import org.apache.commons.exec.PumpStreamHandler; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.sshd.server.session.ServerSession; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @Service public class ReplyService { private static final Logger logger = LoggerFactory.getLogger(ReplyService.class); private static final Logger notFoundLogger = LoggerFactory.getLogger("not_found"); @Autowired Properties hashReplies; @Autowired Properties regexMapping; @Autowired Map ipInfoMapping; @Autowired JdbcService jdbcService; @Autowired GeoIpLocator geoIpLocator; public boolean replyToCommand(String command, OutputStream out, String prompt, ServerSession session) throws IOException { String cmdHash = DigestUtils.md5Hex(command.trim()).toUpperCase(); if (StringUtils.equalsIgnoreCase(command.trim(), "my_geolocation")) { logger.info("[{}] my_geolocation command detected: {}", cmdHash, command.trim()); out.write(String.format("\r\n%s\r\n%s", ipInfoMapping.get(Thread.currentThread().getName()), prompt) .getBytes()); } else if (StringUtils.equalsIgnoreCase(command.trim(), "whoami")) { logger.info("[{}] whoami command detected: {}", cmdHash, command.trim()); out.write(String.format("\r\n%s\r\n%s", session.getUsername(), prompt).getBytes()); } else if (StringUtils.equalsIgnoreCase(command.trim(), "online_geolocations")) { logger.info("[{}] online_geolocations command detected: {}", cmdHash, command.trim()); out.write(String.format("\r\n%s\r\n%s", ipInfoMapping.toString(), prompt).getBytes()); } else if (StringUtils.split(command.trim(), " ").length == 2 && StringUtils.equalsIgnoreCase(StringUtils.split(command.trim(), " ")[0], "get_geolocation")) { String remoteIpInfo = StringUtils.getIfBlank( jdbcService.getRemoteIpInfo(StringUtils.split(command.trim(), " ")[1]), () -> geoIpLocator.getIpLocationInfo(StringUtils.split(command.trim(), " ")[1])); logger.info("[{}] get_geolocation command detected: {}", cmdHash, command.trim()); out.write(String.format("\r\n%s\r\n%s", remoteIpInfo, prompt).getBytes()); } else if (StringUtils.equalsIgnoreCase(command.trim(), "all_geolocations")) { logger.info("[{}] all_geolocations command detected: {}", cmdHash, command.trim()); out.write(String.format("\r\n%s\r\n%s", jdbcService.getAllRemoteIpInfo(), prompt).getBytes()); } else if (StringUtils.equalsIgnoreCase(command.trim(), "exit") || StringUtils.equalsIgnoreCase(command.trim(), "quit")) { logger.info("[{}] Exiting command detected: {}", cmdHash, command.trim()); out.write(String.format("\r\nExiting...\r\n%s", prompt).getBytes()); return true; } else if (hashReplies.containsKey(command.trim())) { logger.info("[{}] Known command detected: {}", cmdHash, command.trim()); String reply = hashReplies.getProperty(command.trim()).replace("\\r", "\r").replace("\\n", "\n") .replace("\\t", "\t"); out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); } else if (hashReplies.containsKey(cmdHash)) { logger.info("[{}] Known command-hash detected: {}", cmdHash, command.trim()); String reply = hashReplies.getProperty(cmdHash).replace("\\r", "\r").replace("\\n", "\n").replace("\\t", "\t"); out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); } else if (hashReplies.containsKey(String.format("base64(%s)", cmdHash))) { logger.info("[{}] Known base64-hash detected: {}", cmdHash, command.trim()); String reply = hashReplies.getProperty(String.format("base64(%s)", cmdHash)); reply = new String(Base64.decodeBase64(reply)); out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); } else { Optional> o = regexMapping.entrySet().stream() .filter(e -> command.trim().matches(((String) e.getKey()))) .map(e -> Pair.of((String) e.getKey(), (String) e.getValue())).findAny(); if (o.isPresent()) { String reply = hashReplies.getProperty(o.get().getRight(), "").replace("\\r", "\r").replace("\\n", "\n") .replace("\\t", "\t"); if (reply.isEmpty()) { logger.info("[{}] Execute cmd for real: {} ({})", cmdHash, command.trim(), o.get()); ByteArrayOutputStream tempOut = new ByteArrayOutputStream(); try { CommandLine cmdLine = CommandLine.parse(command.trim()); DefaultExecutor executor = DefaultExecutor.builder().get(); PumpStreamHandler streamHandler = new PumpStreamHandler(tempOut); executor.setStreamHandler(streamHandler); int exitValue = executor.execute(cmdLine); logger.info("[{}] Result: {} ({})", cmdHash, command.trim(), exitValue); reply = new String(tempOut.toByteArray()).replace("\n", "\r\n"); } catch (ExecuteException e) { logger.info("[{}] Execute cmd failed: {} ({})", cmdHash, command.trim(), o.get(), e); } out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); } else { logger.info("[{}] Known pattern detected: {} ({})", cmdHash, command.trim(), o.get()); out.write(String.format("\r\n%s\r\n%s", reply, prompt).getBytes()); } } else { logger.info("[{}] Command not found: {}", cmdHash, command.trim()); notFoundLogger.info("[{}] Command not found: {}", cmdHash, command.trim()); out.write(String.format("\r\nCommand '%s' not found. Try 'exit'.\r\n%s", command.trim(), prompt) .getBytes()); } } return false; } }