Support versions >= 2.7.50 in query_cur_resion

This commit is contained in:
Nobody 2022-05-18 04:38:55 +05:00
parent 673976095a
commit 69450ae7ea

View File

@ -4,6 +4,8 @@ use std::net::TcpStream;
use std::collections::HashMap;
use std::fs;
use std::sync::Arc;
use std::fs::read_to_string; // use instead of std::fs::File
use std::path::Path;
extern crate futures;
extern crate base64;
@ -11,15 +13,21 @@ extern crate actix_web;
extern crate openssl;
use serde::{de, Deserialize, Deserializer, Serialize};
use serde::de::Error;
use futures::executor;
use actix_web::{rt::System, web, get, App, HttpRequest, HttpResponse, HttpServer, Responder, middleware::Logger};
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod, SslVerifyMode, SslOptions, SslMode};
use openssl::rsa::{Rsa, Padding};
use openssl::symm::Cipher;
use rand::{distributions::Alphanumeric, Rng};
use prost::Message;
use mhycrypt;
use openssl::hash::MessageDigest;
use openssl::pkey::{PKey, Private, Public};
use openssl::sign::Signer;
use pretty_env_logger::env_logger::fmt;
use serde::de::Unexpected;
//use openssl::rand;
@ -27,6 +35,34 @@ use serde::de::Unexpected;
#[derive(Clone)]
pub struct DispatchServer {}
// Keys stuff
#[derive(Deserialize,Debug)]
struct KeyInfo {
key_id: u8,
#[serde(deserialize_with = "deserialize_pub_key")]
public_key: Rsa<Public>,
#[serde(deserialize_with = "deserialize_priv_key")]
private_key: Rsa<Private>,
}
fn deserialize_pub_key<'de, D>(deserializer: D) -> Result<Rsa<Public>, D::Error>
where
D: Deserializer<'de>,
{
let public_key_pem: &str = Deserialize::deserialize(deserializer)?;
Rsa::public_key_from_pem(public_key_pem.as_bytes()).map_err(D::Error::custom)
}
fn deserialize_priv_key<'de, D>(deserializer: D) -> Result<Rsa<Private>, D::Error>
where
D: Deserializer<'de>,
{
let private_key_pem: &str = Deserialize::deserialize(deserializer)?;
Rsa::private_key_from_pem(private_key_pem.as_bytes()).map_err(D::Error::custom)
}
#[derive(Deserialize,Debug)]
struct ClientInfo {
version: String,
@ -37,6 +73,7 @@ struct ClientInfo {
channel_id: i32,
sub_channel_id: i32,
account_type: Option<i32>,
key_id: Option<u8>,
}
#[derive(Deserialize,Debug)]
@ -268,8 +305,40 @@ impl DispatchServer {
region_config.encode(&mut region_conf_buf).unwrap();
if c.0.version.contains("2.7.5") || c.0.version.contains("2.8.") {// TODO: use proper version check!
let key_id = match c.0.key_id {
Some(key_id) => key_id,
None => panic!("Client version >= 2.7.50, but it haven't provided key_id!"),
};
let rsa_key_collection = DispatchServer::load_rsa_keys("RSAConfig");
let keys = match rsa_key_collection.get(&key_id) {
Some(keys) => keys,
None => panic!("Unknown key ID {}!", key_id),
};
let mut out_buf: Vec<u8> = Vec::new();
let mut enc_buf: Vec<u8> = vec![0; keys.public_key.size() as usize];
for chunk in region_conf_buf.chunks(245) { // TODO: value hardcoded for the 2048-bit key!
keys.public_key.public_encrypt(chunk, &mut enc_buf, Padding::PKCS1).unwrap();
out_buf.append(&mut enc_buf);
}
let keypair = PKey::from_rsa(keys.private_key.clone()).unwrap();
let mut signer = Signer::new(MessageDigest::sha256(), &keypair).unwrap();
signer.update(&region_conf_buf).unwrap();
let signature = signer.sign_to_vec().unwrap();
return format!("
{{
\"content\": \"{}\",
\"sign\": \"{}\"
}}
", base64::encode(out_buf), base64::encode(signature));
} else {
return base64::encode(region_conf_buf);
}
}
async fn risky_api_check_old(a: web::Json<ActionToCheck>) -> String {
println!("Action: {:?}", a);
@ -634,6 +703,16 @@ impl DispatchServer {
return "127.0.0.1".to_string();
}
fn load_rsa_keys(name: &str) -> HashMap<u8, KeyInfo> {
// Key depo
let path = format!("./{}/{}.json", "keys", name);
let json_file_path = Path::new(&path);
let json_file_str = read_to_string(json_file_path).unwrap_or_else(|_| panic!("File {} not found", path));
let data: Vec<KeyInfo> = serde_json::from_str(&json_file_str).expect(&format!("Error while reading json {}", name));
data.into_iter().map(|ki| (ki.key_id, ki)).collect()
}
fn load_keys(name: &str) -> (Vec<u8>, Vec<u8>) {
// Key
let filename = format!("./{}/{}.key", "keys", name);