From 245b90071cf1f98fc9c70555db590b883011a653 Mon Sep 17 00:00:00 2001
From: klzgrad <kizdiv@gmail.com>
Date: Thu, 10 Jan 2019 00:30:56 -0500
Subject: [PATCH] Support loading config.json

---
 src/net/tools/naive/naive_proxy_bin.cc | 169 ++++++++++++++++++-------
 1 file changed, 123 insertions(+), 46 deletions(-)

diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc
index 7ee14d18a5..0dd8a7ef5b 100644
--- a/src/net/tools/naive/naive_proxy_bin.cc
+++ b/src/net/tools/naive/naive_proxy_bin.cc
@@ -12,6 +12,7 @@
 #include "base/at_exit.h"
 #include "base/command_line.h"
 #include "base/files/file_path.h"
+#include "base/json/json_file_value_serializer.h"
 #include "base/json/json_writer.h"
 #include "base/logging.h"
 #include "base/macros.h"
@@ -71,6 +72,17 @@ constexpr int kExpectedMaxUsers = 8;
 constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
     net::DefineNetworkTrafficAnnotation("naive", "");
 
+struct CommandLine {
+  std::string listen;
+  std::string proxy;
+  bool padding;
+  std::string host_resolver_rules;
+  bool no_log;
+  base::FilePath log;
+  base::FilePath log_net_log;
+  base::FilePath ssl_key_log_file;
+};
+
 struct Params {
   net::NaiveConnection::Protocol protocol;
   std::string listen_addr;
@@ -81,7 +93,6 @@ struct Params {
   base::string16 proxy_pass;
   std::string host_resolver_rules;
   logging::LoggingSettings log_settings;
-  base::FilePath log_path;
   base::FilePath net_log_path;
   base::FilePath ssl_key_path;
 };
@@ -111,6 +122,7 @@ std::unique_ptr<net::URLRequestContext> BuildURLRequestContext(
 
   net::ProxyConfig proxy_config;
   proxy_config.proxy_rules().ParseFromString(params.proxy_url);
+  LOG(INFO) << "Proxying via " << params.proxy_url;
   auto proxy_service =
       net::ConfiguredProxyResolutionService::CreateWithoutProxyResolver(
           std::make_unique<net::ProxyConfigServiceFixed>(
@@ -149,11 +161,9 @@ std::unique_ptr<net::URLRequestContext> BuildURLRequestContext(
   return context;
 }
 
-bool ParseCommandLineFlags(Params* params) {
-  const base::CommandLine& line = *base::CommandLine::ForCurrentProcess();
-
-  if (line.HasSwitch("h") || line.HasSwitch("help")) {
-    std::cout << "Usage: naive [options]\n"
+void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) {
+  if (proc.HasSwitch("h") || proc.HasSwitch("help")) {
+    std::cout << "Usage: naive { OPTIONS | config.json }\n"
                  "\n"
                  "Options:\n"
                  "-h, --help                 Show this message\n"
@@ -163,26 +173,90 @@ bool ParseCommandLineFlags(Params* params) {
                  "--proxy=<proto>://[<user>:<pass>@]<hostname>[:<port>]\n"
                  "                           proto: https, quic\n"
                  "--padding                  Use padding\n"
+                 "--host-resolver-rules=...  Resolver rules\n"
                  "--log[=<path>]             Log to stderr, or file\n"
                  "--log-net-log=<path>       Save NetLog\n"
                  "--ssl-key-log-file=<path>  Save SSL keys for Wireshark\n"
               << std::endl;
     exit(EXIT_SUCCESS);
-    return false;
   }
 
-  if (line.HasSwitch("version")) {
-    std::cout << "Version: " << version_info::GetVersionNumber() << std::endl;
+  if (proc.HasSwitch("version")) {
+    std::cout << "naive " << version_info::GetVersionNumber() << std::endl;
     exit(EXIT_SUCCESS);
-    return false;
   }
 
+  cmdline->listen = proc.GetSwitchValueASCII("listen");
+  cmdline->proxy = proc.GetSwitchValueASCII("proxy");
+  cmdline->padding = proc.HasSwitch("padding");
+  cmdline->host_resolver_rules =
+      proc.GetSwitchValueASCII("host-resolver-rules");
+  cmdline->no_log = !proc.HasSwitch("log");
+  cmdline->log = proc.GetSwitchValuePath("log");
+  cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log");
+  cmdline->ssl_key_log_file = proc.GetSwitchValuePath("ssl-key-log-file");
+}
+
+void GetCommandLineFromConfig(const base::FilePath& config_path,
+                              CommandLine* cmdline) {
+  JSONFileValueDeserializer reader(config_path);
+  int error_code;
+  std::string error_message;
+  auto value = reader.Deserialize(&error_code, &error_message);
+  if (value == nullptr) {
+    std::cerr << "Error reading " << config_path << ": (" << error_code << ") "
+              << error_message << std::endl;
+    exit(EXIT_FAILURE);
+  }
+  if (!value->is_dict()) {
+    std::cerr << "Invalid config format" << std::endl;
+    exit(EXIT_FAILURE);
+  }
+  if (value->FindKeyOfType("listen", base::Value::Type::STRING)) {
+    cmdline->listen = value->FindKey("listen")->GetString();
+  }
+  if (value->FindKeyOfType("proxy", base::Value::Type::STRING)) {
+    cmdline->proxy = value->FindKey("proxy")->GetString();
+  }
+  cmdline->padding = false;
+  if (value->FindKeyOfType("padding", base::Value::Type::BOOLEAN)) {
+    cmdline->padding = value->FindKey("padding")->GetBool();
+  }
+  if (value->FindKeyOfType("host-resolver-rules", base::Value::Type::STRING)) {
+    cmdline->host_resolver_rules =
+        value->FindKey("host-resolver-rules")->GetString();
+  }
+  cmdline->no_log = true;
+  if (value->FindKeyOfType("log", base::Value::Type::STRING)) {
+    cmdline->no_log = false;
+    cmdline->log =
+        base::FilePath::FromUTF8Unsafe(value->FindKey("log")->GetString());
+  }
+  if (value->FindKeyOfType("log-net-log", base::Value::Type::STRING)) {
+    cmdline->log_net_log = base::FilePath::FromUTF8Unsafe(
+        value->FindKey("log-net-log")->GetString());
+  }
+  if (value->FindKeyOfType("ssl-key-log-file", base::Value::Type::STRING)) {
+    cmdline->ssl_key_log_file = base::FilePath::FromUTF8Unsafe(
+        value->FindKey("ssl-key-log-file")->GetString());
+  }
+}
+
+std::string GetProxyFromURL(const GURL& url) {
+  std::string str = url.GetWithEmptyPath().spec();
+  if (str.size() && str.back() == '/') {
+    str.pop_back();
+  }
+  return str;
+}
+
+bool ParseCommandLine(const CommandLine& cmdline, Params* params) {
   params->protocol = net::NaiveConnection::kSocks5;
   params->listen_addr = "0.0.0.0";
   params->listen_port = 1080;
   url::AddStandardScheme("socks", url::SCHEME_WITH_HOST_AND_PORT);
-  if (line.HasSwitch("listen")) {
-    GURL url(line.GetSwitchValueASCII("listen"));
+  if (!cmdline.listen.empty()) {
+    GURL url(cmdline.listen);
     if (url.scheme() == "socks") {
       params->protocol = net::NaiveConnection::kSocks5;
       params->listen_port = 1080;
@@ -190,7 +264,7 @@ bool ParseCommandLineFlags(Params* params) {
       params->protocol = net::NaiveConnection::kHttp;
       params->listen_port = 8080;
     } else {
-      LOG(ERROR) << "Invalid scheme in --listen";
+      std::cerr << "Invalid scheme in --listen" << std::endl;
       return false;
     }
     if (!url.host().empty()) {
@@ -198,12 +272,12 @@ bool ParseCommandLineFlags(Params* params) {
     }
     if (!url.port().empty()) {
       if (!base::StringToInt(url.port(), &params->listen_port)) {
-        LOG(ERROR) << "Invalid port in --listen";
+        std::cerr << "Invalid port in --listen" << std::endl;
         return false;
       }
       if (params->listen_port <= 0 ||
           params->listen_port > std::numeric_limits<uint16_t>::max()) {
-        LOG(ERROR) << "Invalid port in --listen";
+        std::cerr << "Invalid port in --listen" << std::endl;
         return false;
       }
     }
@@ -212,50 +286,37 @@ bool ParseCommandLineFlags(Params* params) {
   url::AddStandardScheme("quic",
                          url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION);
   params->proxy_url = "direct://";
-  GURL url(line.GetSwitchValueASCII("proxy"));
-  if (line.HasSwitch("proxy")) {
+  GURL url(cmdline.proxy);
+  GURL::Replacements remove_auth;
+  remove_auth.ClearUsername();
+  remove_auth.ClearPassword();
+  GURL url_no_auth = url.ReplaceComponents(remove_auth);
+  if (!cmdline.proxy.empty()) {
     if (!url.is_valid()) {
-      LOG(ERROR) << "Invalid proxy URL";
+      std::cerr << "Invalid proxy URL" << std::endl;
       return false;
     }
-    if (url.scheme() != "https" && url.scheme() != "quic") {
-      LOG(ERROR) << "Must be HTTPS or QUIC proxy";
-      return false;
-    }
-    params->proxy_url = url::SchemeHostPort(url).Serialize();
+    params->proxy_url = GetProxyFromURL(url_no_auth);
     net::GetIdentityFromURL(url, &params->proxy_user, &params->proxy_pass);
   }
 
-  params->use_padding = false;
-  if (line.HasSwitch("padding")) {
-    params->use_padding = true;
-  }
+  params->use_padding = cmdline.padding;
 
-  if (line.HasSwitch("host-resolver-rules")) {
-    params->host_resolver_rules =
-        line.GetSwitchValueASCII("host-resolver-rules");
-  }
+  params->host_resolver_rules = cmdline.host_resolver_rules;
 
-  if (line.HasSwitch("log")) {
-    params->log_settings.logging_dest = logging::LOG_DEFAULT;
-    params->log_path = line.GetSwitchValuePath("log");
-    if (!params->log_path.empty()) {
+  if (!cmdline.no_log) {
+    if (!cmdline.log.empty()) {
       params->log_settings.logging_dest = logging::LOG_TO_FILE;
-    } else if (params->log_settings.logging_dest == logging::LOG_TO_FILE) {
-      params->log_path = base::FilePath::FromUTF8Unsafe("naive.log");
+      params->log_settings.log_file_path = cmdline.log.value().c_str();
+    } else {
+      params->log_settings.logging_dest = logging::LOG_TO_STDERR;
     }
-    params->log_settings.log_file_path = params->log_path.value().c_str();
   } else {
     params->log_settings.logging_dest = logging::LOG_NONE;
   }
 
-  if (line.HasSwitch("log-net-log")) {
-    params->net_log_path = line.GetSwitchValuePath("log-net-log");
-  }
-
-  if (line.HasSwitch("ssl-key-log-file")) {
-    params->ssl_key_path = line.GetSwitchValuePath("ssl-key-log-file");
-  }
+  params->net_log_path = cmdline.log_net_log;
+  params->ssl_key_path = cmdline.ssl_key_log_file;
 
   return true;
 }
@@ -317,8 +378,22 @@ int main(int argc, char* argv[]) {
 
   base::CommandLine::Init(argc, argv);
 
+  CommandLine cmdline;
   Params params;
-  if (!ParseCommandLineFlags(&params)) {
+  const auto& proc = *base::CommandLine::ForCurrentProcess();
+  const auto& args = proc.GetArgs();
+  if (args.empty()) {
+    if (proc.argv().size() >= 2) {
+      GetCommandLine(proc, &cmdline);
+    } else {
+      auto path = base::FilePath::FromUTF8Unsafe("config.json");
+      GetCommandLineFromConfig(path, &cmdline);
+    }
+  } else {
+    base::FilePath path(args[0]);
+    GetCommandLineFromConfig(path, &cmdline);
+  }
+  if (!ParseCommandLine(cmdline, &params)) {
     return EXIT_FAILURE;
   }
 
@@ -374,6 +449,8 @@ int main(int argc, char* argv[]) {
     LOG(ERROR) << "Failed to listen: " << result;
     return EXIT_FAILURE;
   }
+  LOG(INFO) << "Listening on " << params.listen_addr << ":"
+            << params.listen_port;
 
   net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol,
                               params.use_padding, session, kTrafficAnnotation);