From 7226904e04ddca11655e5295ff40c1e5a3866440 Mon Sep 17 00:00:00 2001 From: soul-walker <31162815+soul-walker@users.noreply.github.com> Date: Tue, 22 Oct 2024 17:24:50 +0800 Subject: [PATCH] =?UTF-8?q?1,=20=E6=B7=BB=E5=8A=A0=E5=88=9B=E5=BB=BA?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=9F=8E=E5=B8=82=E8=BD=A8=E9=81=93=E4=BA=A4?= =?UTF-8?q?=E9=80=9A=E4=BB=BF=E7=9C=9F=E5=8A=9F=E8=83=BD=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=202,=20=E6=B7=BB=E5=8A=A0mqtt=E5=AE=A2=E6=88=B7=E7=AB=AFcrate?= =?UTF-8?q?=EF=BC=8C=E6=8F=90=E4=BE=9B=E8=AE=A2=E9=98=85=E3=80=81=E5=8F=91?= =?UTF-8?q?=E5=B8=83=E3=80=81=E5=8F=91=E9=80=81=E8=AF=B7=E6=B1=82=E7=AD=89?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 8 + Cargo.lock | 169 ++++- Cargo.toml | 12 +- .../apis/{common.rs => data_options_def.rs} | 2 +- crates/rtss_api/src/apis/draft_data.rs | 11 +- crates/rtss_api/src/apis/feature.rs | 112 +++- .../rtss_api/src/apis/feature_config_def.rs | 17 + crates/rtss_api/src/apis/mod.rs | 3 +- crates/rtss_api/src/apis/release_data.rs | 2 +- crates/rtss_api/src/server.rs | 62 +- crates/rtss_api/src/user_auth/mod.rs | 41 +- crates/rtss_db/src/db_access/mod.rs | 30 + crates/rtss_db/src/model.rs | 18 +- crates/rtss_mqtt/Cargo.toml | 15 + crates/rtss_mqtt/src/error.rs | 14 + crates/rtss_mqtt/src/lib.rs | 612 ++++++++++++++++++ migrations/20240830095636_init.up.sql | 3 + 17 files changed, 1049 insertions(+), 82 deletions(-) rename crates/rtss_api/src/apis/{common.rs => data_options_def.rs} (96%) create mode 100644 crates/rtss_api/src/apis/feature_config_def.rs create mode 100644 crates/rtss_mqtt/Cargo.toml create mode 100644 crates/rtss_mqtt/src/error.rs create mode 100644 crates/rtss_mqtt/src/lib.rs diff --git a/.vscode/settings.json b/.vscode/settings.json index ce8c14c..f15e1ed 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,6 +2,8 @@ "cSpell.words": [ "chrono", "cpus", + "dashmap", + "eventloop", "Graphi", "graphiql", "hashbrown", @@ -11,7 +13,11 @@ "Joylink", "jsonwebtoken", "mplj", + "Mqtt", + "mqttbytes", "Neng", + "nextval", + "oneshot", "plpgsql", "prost", "proto", @@ -20,6 +26,8 @@ "repr", "reqwest", "rtss", + "rumqtt", + "rumqttc", "sqlx", "sysinfo", "thiserror", diff --git a/Cargo.lock b/Cargo.lock index 831b6e9..0d1b94f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -121,9 +121,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.88" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356" +checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95" [[package]] name = "ascii_utils" @@ -270,9 +270,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" [[package]] name = "async-trait" -version = "0.1.81" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", @@ -619,9 +619,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" dependencies = [ "serde", ] @@ -658,9 +658,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.17" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ "clap_builder", "clap_derive", @@ -668,9 +668,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.17" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" dependencies = [ "anstream", "anstyle", @@ -680,9 +680,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.13" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ "heck", "proc-macro2", @@ -776,6 +776,16 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1407,10 +1417,10 @@ dependencies = [ "http", "hyper", "hyper-util", - "rustls", + "rustls 0.23.13", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.0", "tower-service", "webpki-roots", ] @@ -1816,6 +1826,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "ordered-multimap" version = "0.6.0" @@ -2119,7 +2135,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls", + "rustls 0.23.13", "socket2", "thiserror", "tokio", @@ -2136,7 +2152,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls", + "rustls 0.23.13", "slab", "thiserror", "tinyvec", @@ -2301,7 +2317,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls", + "rustls 0.23.13", "rustls-pemfile", "rustls-pki-types", "serde", @@ -2309,7 +2325,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.0", "tower-service", "url", "wasm-bindgen", @@ -2440,6 +2456,20 @@ dependencies = [ "tracing-wasm", ] +[[package]] +name = "rtss_mqtt" +version = "0.1.0" +dependencies = [ + "async-trait", + "bytes", + "lazy_static", + "rtss_db", + "rtss_log", + "rumqttc", + "thiserror", + "tokio", +] + [[package]] name = "rtss_sim_manage" version = "0.1.0" @@ -2482,6 +2512,25 @@ dependencies = [ "rtss_log", ] +[[package]] +name = "rumqttc" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1568e15fab2d546f940ed3a21f48bbbd1c494c90c99c4481339364a497f94a9" +dependencies = [ + "bytes", + "flume", + "futures-util", + "log", + "rustls-native-certs", + "rustls-pemfile", + "rustls-webpki", + "thiserror", + "tokio", + "tokio-rustls 0.25.0", + "url", +] + [[package]] name = "rust-ini" version = "0.19.0" @@ -2517,6 +2566,20 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + [[package]] name = "rustls" version = "0.23.13" @@ -2531,6 +2594,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "2.1.3" @@ -2570,12 +2646,44 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "schannel" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.210" @@ -2598,9 +2706,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.125" +version = "1.0.131" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" +checksum = "67d42a0bd4ac281beff598909bb56a86acaf979b84483e1c79c10dcaf98f8cf3" dependencies = [ "itoa", "memchr", @@ -3072,18 +3180,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", @@ -3182,13 +3290,24 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +dependencies = [ + "rustls 0.22.4", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls", + "rustls 0.23.13", "rustls-pki-types", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index e09bfb5..f1988e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ bevy_ecs = "0.14" bevy_time = "0.14" rayon = "1.10" tokio = { version = "1.40", features = ["macros", "rt-multi-thread"] } -thiserror = "1.0" +thiserror = "1.0.64" sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", @@ -23,8 +23,12 @@ sqlx = { version = "0.8", features = [ "chrono", ] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0.125" -anyhow = "1.0" +serde_json = "1.0.131" +anyhow = "1.0.90" +async-trait = "0.1.83" +bytes = "1.7.2" +lazy_static = "1.5.0" + [dependencies] tokio = { version = "1.39.3", features = ["macros", "rt-multi-thread"] } @@ -33,6 +37,6 @@ rtss_api = { path = "crates/rtss_api" } rtss_db = { path = "crates/rtss_db" } serde = { workspace = true } config = "0.14.0" -clap = { version = "4.5", features = ["derive"] } +clap = { version = "4.5.20", features = ["derive"] } enum_dispatch = "0.3" anyhow = { workspace = true } diff --git a/crates/rtss_api/src/apis/common.rs b/crates/rtss_api/src/apis/data_options_def.rs similarity index 96% rename from crates/rtss_api/src/apis/common.rs rename to crates/rtss_api/src/apis/data_options_def.rs index c59a857..b0ea32b 100644 --- a/crates/rtss_api/src/apis/common.rs +++ b/crates/rtss_api/src/apis/data_options_def.rs @@ -10,7 +10,7 @@ pub trait DataOptions: InputType + OutputType + Serialize + DeserializeOwned { impl DataOptions for Value { fn to_data_options_filter_clause(&self) -> String { - format!("options @> '{}'", self) + format!("{} @> '{}'", DraftDataColumn::Options.name(), self) } } diff --git a/crates/rtss_api/src/apis/draft_data.rs b/crates/rtss_api/src/apis/draft_data.rs index 7166080..bd826b2 100644 --- a/crates/rtss_api/src/apis/draft_data.rs +++ b/crates/rtss_api/src/apis/draft_data.rs @@ -11,7 +11,7 @@ use serde_json::Value; use crate::apis::{PageDto, PageQueryDto}; use crate::loader::RtssDbLoader; -use super::common::{DataOptions, IscsDataOptions}; +use super::data_options_def::{DataOptions, IscsDataOptions}; use super::release_data::ReleaseDataId; use super::user::UserId; @@ -115,7 +115,7 @@ impl DraftDataMutation { .data::()? .query_user(&ctx.data::()?.0) .await?; - input = input.with_user_id(user.id_i32()); + input = input.with_data_type_and_user_id(DataType::Iscs, user.id_i32()); let db_accessor = ctx.data::()?; let draft_data = db_accessor.create_draft_data(input.into()).await?; Ok(draft_data.into()) @@ -207,6 +207,8 @@ impl DraftDataMutation { #[derive(Debug, InputObject)] #[graphql(concrete(name = "CreateDraftIscsDto", params(IscsDataOptions)))] pub struct CreateDraftDataDto { + #[graphql(skip)] + pub data_type: Option, pub name: String, pub options: Option, #[graphql(skip)] @@ -214,7 +216,8 @@ pub struct CreateDraftDataDto { } impl CreateDraftDataDto { - pub fn with_user_id(mut self, id: i32) -> Self { + pub fn with_data_type_and_user_id(mut self, data_type: DataType, id: i32) -> Self { + self.data_type = Some(data_type); self.user_id = Some(id); self } @@ -224,7 +227,7 @@ impl From> for rtss_db::CreateDraftData { fn from(value: CreateDraftDataDto) -> Self { let cdd = Self::new( &value.name, - DataType::Iscs, + value.data_type.expect("need data_type"), value.user_id.expect("CreateDraftDataDto need user_id"), ); if value.options.is_some() { diff --git a/crates/rtss_api/src/apis/feature.rs b/crates/rtss_api/src/apis/feature.rs index c548b99..ba597aa 100644 --- a/crates/rtss_api/src/apis/feature.rs +++ b/crates/rtss_api/src/apis/feature.rs @@ -1,18 +1,22 @@ +use crate::{ + apis::{PageDto, PageQueryDto}, + loader::RtssDbLoader, + user_auth::{RoleGuard, Token, UserAuthCache}, +}; use async_graphql::{ dataloader::DataLoader, ComplexObject, Context, InputObject, Object, SimpleObject, }; use chrono::NaiveDateTime; -use rtss_db::{FeatureAccessor, RtssDbAccessor}; +use rtss_db::{CreateFeature, FeatureAccessor, RtssDbAccessor, UpdateFeature}; use rtss_dto::common::FeatureType; +use rtss_dto::common::Role; use serde_json::Value; -use crate::{ - apis::{PageDto, PageQueryDto}, - loader::RtssDbLoader, +use super::{ + feature_config_def::{FeatureConfig, UrFeatureConfig}, + user::UserId, }; -use super::user::UserId; - #[derive(Default)] pub struct FeatureQuery; @@ -21,7 +25,8 @@ pub struct FeatureMutation; #[Object] impl FeatureQuery { - /// 分页查询特征(系统管理) + /// 分页查询功能feature(系统管理) + #[graphql(guard = "RoleGuard::new(Role::Admin)")] async fn feature_paging( &self, ctx: &Context<'_>, @@ -35,14 +40,16 @@ impl FeatureQuery { Ok(paging.into()) } - /// id获取特征 + /// id获取功能feature + #[graphql(guard = "RoleGuard::new(Role::User)")] async fn feature(&self, ctx: &Context<'_>, id: i32) -> async_graphql::Result { let dba = ctx.data::()?; let feature = dba.get_feature(id).await?; Ok(feature.into()) } - /// id列表获取特征 + /// id列表获取功能feature列表 + #[graphql(guard = "RoleGuard::new(Role::User)")] async fn features( &self, ctx: &Context<'_>, @@ -56,7 +63,8 @@ impl FeatureQuery { #[Object] impl FeatureMutation { - /// 上下架特征 + /// 上下架功能feature + #[graphql(guard = "RoleGuard::new(Role::Admin)")] async fn publish_feature( &self, ctx: &Context<'_>, @@ -67,6 +75,90 @@ impl FeatureMutation { let feature = dba.set_feature_published(id, is_published).await?; Ok(feature.into()) } + + /// 创建城轨仿真功能feature + #[graphql(guard = "RoleGuard::new(Role::Admin)")] + async fn create_ur_feature( + &self, + ctx: &Context<'_>, + mut input: CreateFeatureDto, + ) -> async_graphql::Result { + let dba = ctx.data::()?; + let user = ctx + .data::()? + .query_user(&ctx.data::()?.0) + .await?; + input = input.with_feature_type_and_user_id(FeatureType::Ur, user.id_i32()); + let feature = dba.create_feature(&input.into()).await?; + Ok(feature.into()) + } + + /// 更新城轨仿真功能feature + #[graphql(guard = "RoleGuard::new(Role::Admin)")] + async fn update_ur_feature( + &self, + ctx: &Context<'_>, + input: UpdateFeatureDto, + ) -> async_graphql::Result { + let dba = ctx.data::()?; + let feature = dba.update_feature(&input.into()).await?; + Ok(feature.into()) + } +} + +#[derive(Debug, InputObject)] +#[graphql(concrete(name = "UpdateUrFeatureDto", params(UrFeatureConfig)))] +pub struct UpdateFeatureDto { + pub id: i32, + pub name: String, + pub description: String, + pub config: T, + #[graphql(skip)] + pub user_id: Option, +} + +impl From> for UpdateFeature { + fn from(value: UpdateFeatureDto) -> Self { + Self { + id: value.id, + name: value.name, + description: value.description, + config: serde_json::to_value(&value.config).expect("config is to_value failed"), + updater_id: value.user_id.expect("user_id must be set"), + } + } +} + +#[derive(Debug, InputObject)] +#[graphql(concrete(name = "CreateUrFeatureDto", params(UrFeatureConfig)))] +pub struct CreateFeatureDto { + #[graphql(skip)] + pub feature_type: Option, + pub name: String, + pub description: String, + pub config: T, + #[graphql(skip)] + pub user_id: Option, +} + +impl From> for CreateFeature { + fn from(value: CreateFeatureDto) -> Self { + Self { + feature_type: value.feature_type.expect("feature_type must be set"), + name: value.name, + description: value.description, + config: serde_json::to_value(&value.config).expect("config is to_value failed"), + creator_id: value.user_id.expect("user_id must be set"), + } + } +} + +impl CreateFeatureDto { + fn with_feature_type_and_user_id(mut self, feature_type: FeatureType, uid: i32) -> Self { + self.feature_type = Some(feature_type); + self.user_id = Some(uid); + self + } } #[derive(Debug, InputObject)] diff --git a/crates/rtss_api/src/apis/feature_config_def.rs b/crates/rtss_api/src/apis/feature_config_def.rs new file mode 100644 index 0000000..69e2f7b --- /dev/null +++ b/crates/rtss_api/src/apis/feature_config_def.rs @@ -0,0 +1,17 @@ +use async_graphql::{InputObject, InputType, OutputType, SimpleObject}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::Value; + +pub trait FeatureConfig: InputType + OutputType + Serialize + DeserializeOwned {} + +impl FeatureConfig for Value {} + +/// UR功能配置 +#[derive(Debug, Clone, InputObject, SimpleObject, Serialize, Deserialize)] +#[graphql(input_name = "UrFeatureConfigInput")] +pub struct UrFeatureConfig { + /// 电子地图id + pub ems: Vec, +} + +impl FeatureConfig for UrFeatureConfig {} diff --git a/crates/rtss_api/src/apis/mod.rs b/crates/rtss_api/src/apis/mod.rs index 89127b7..47c5846 100644 --- a/crates/rtss_api/src/apis/mod.rs +++ b/crates/rtss_api/src/apis/mod.rs @@ -8,9 +8,10 @@ mod sys_info; use simulation_definition::*; use user::{UserMutation, UserQuery}; -mod common; +mod data_options_def; mod draft_data; mod feature; +mod feature_config_def; mod release_data; mod simulation; mod user; diff --git a/crates/rtss_api/src/apis/release_data.rs b/crates/rtss_api/src/apis/release_data.rs index 16bbd74..442ee7c 100644 --- a/crates/rtss_api/src/apis/release_data.rs +++ b/crates/rtss_api/src/apis/release_data.rs @@ -14,7 +14,7 @@ use serde_json::Value; use crate::apis::draft_data::DraftDataDto; use crate::loader::RtssDbLoader; -use super::common::{DataOptions, IscsDataOptions}; +use super::data_options_def::{DataOptions, IscsDataOptions}; use super::user::UserId; use super::{PageDto, PageQueryDto}; diff --git a/crates/rtss_api/src/server.rs b/crates/rtss_api/src/server.rs index dbf7e13..a1dbe69 100644 --- a/crates/rtss_api/src/server.rs +++ b/crates/rtss_api/src/server.rs @@ -10,6 +10,7 @@ use axum::{ }; use dataloader::DataLoader; use http::{playground_source, GraphQLPlaygroundConfig}; +use rtss_db::RtssDbAccessor; use rtss_log::tracing::{debug, info}; use tokio::net::TcpListener; use tower_http::cors::CorsLayer; @@ -47,7 +48,12 @@ impl ServerConfig { } pub async fn serve(config: ServerConfig) -> anyhow::Result<()> { - let schema = new_schema(config.clone()).await; + let client = config + .user_auth_client + .clone() + .expect("user auth client not configured"); + let dba = rtss_db::get_db_accessor(&config.database_url).await; + let schema = new_schema(SchemaOptions::new(client, dba)); let app = Router::new() .route("/", get(graphiql).post(graphql_handler)) @@ -88,18 +94,52 @@ async fn graphiql() -> impl IntoResponse { pub type RtssAppSchema = Schema; -pub async fn new_schema(config: ServerConfig) -> RtssAppSchema { - let client = config - .user_auth_client - .expect("user auth client not configured"); - let user_info_cache = crate::user_auth::UserAuthCache::new(client.clone()); - let dba = rtss_db::get_db_accessor(&config.database_url).await; - let loader = RtssDbLoader::new(dba.clone()); +pub struct SchemaOptions { + pub user_auth_client: UserAuthClient, + pub user_info_cache: user_auth::UserAuthCache, + pub rtss_dba: RtssDbAccessor, +} + +impl SchemaOptions { + pub fn new(user_auth_client: UserAuthClient, rtss_dba: RtssDbAccessor) -> Self { + let user_info_cache = user_auth::UserAuthCache::new(user_auth_client.clone()); + Self { + user_auth_client, + user_info_cache, + rtss_dba, + } + } +} + +pub fn new_schema(options: SchemaOptions) -> RtssAppSchema { + let loader = RtssDbLoader::new(options.rtss_dba.clone()); Schema::build(Query::default(), Mutation::default(), EmptySubscription) - .data(client) - .data(user_info_cache) - .data(dba) + .data(options.user_auth_client) + .data(options.user_info_cache) + .data(options.rtss_dba) .data(DataLoader::new(loader, tokio::spawn)) // .data(MutexSimulationManager::default()) .finish() } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_new_schema() { + let dba = + rtss_db::get_db_accessor("postgresql://joylink:Joylink@0503@localhost:5432/joylink") + .await; + let _ = new_schema(SchemaOptions::new( + crate::UserAuthClient { + base_url: "".to_string(), + login_url: "".to_string(), + logout_url: "".to_string(), + user_info_url: "".to_string(), + sync_user_url: "".to_string(), + }, + dba, + )); + } +} diff --git a/crates/rtss_api/src/user_auth/mod.rs b/crates/rtss_api/src/user_auth/mod.rs index 8c3d6fa..cb8d4ca 100644 --- a/crates/rtss_api/src/user_auth/mod.rs +++ b/crates/rtss_api/src/user_auth/mod.rs @@ -276,9 +276,6 @@ impl UserAuthCache { #[cfg(test)] mod tests { - use anyhow::Ok; - - use rtss_log::tracing::Level; use super::*; @@ -302,25 +299,25 @@ mod tests { println!("{:?}", dt); } - #[tokio::test] - async fn test_user_auth_cache() -> anyhow::Result<()> { - rtss_log::Logging::default().with_level(Level::DEBUG).init(); - let client = UserAuthClient { - base_url: "http://192.168.33.233/rtss-server".to_string(), - login_url: "/api/login".to_string(), - logout_url: "/api/login/logout".to_string(), - user_info_url: "/api/login/getUserInfo".to_string(), - sync_user_url: "/api/userinfo/list/all".to_string(), - }; - let cache = UserAuthCache::new(client.clone()); - let token = cache.client.login(LoginInfo::default()).await?; - let user = cache.query_user(&token).await?; - println!("token: {}, {:?}", token, user); - assert_eq!(cache.len(), 1); + // #[tokio::test] + // async fn test_user_auth_cache() -> anyhow::Result<()> { + // rtss_log::Logging::default().with_level(Level::DEBUG).init(); + // let client = UserAuthClient { + // base_url: "http://192.168.33.233/rtss-server".to_string(), + // login_url: "/api/login".to_string(), + // logout_url: "/api/login/logout".to_string(), + // user_info_url: "/api/login/getUserInfo".to_string(), + // sync_user_url: "/api/userinfo/list/all".to_string(), + // }; + // let cache = UserAuthCache::new(client.clone()); + // let token = cache.client.login(LoginInfo::default()).await?; + // let user = cache.query_user(&token).await?; + // println!("token: {}, {:?}", token, user); + // assert_eq!(cache.len(), 1); - let user_list = client.query_all_users(&Token(token)).await?; - println!("{:?}", user_list); + // let user_list = client.query_all_users(&Token(token)).await?; + // println!("{:?}", user_list); - Ok(()) - } + // Ok(()) + // } } diff --git a/crates/rtss_db/src/db_access/mod.rs b/crates/rtss_db/src/db_access/mod.rs index bbd5300..adaa4d4 100644 --- a/crates/rtss_db/src/db_access/mod.rs +++ b/crates/rtss_db/src/db_access/mod.rs @@ -7,6 +7,8 @@ pub use user::*; mod feature; pub use feature::*; +use crate::{model::MqttClientIdSeq, DbAccessError}; + #[derive(Clone)] pub struct RtssDbAccessor { pool: sqlx::PgPool, @@ -16,9 +18,37 @@ impl RtssDbAccessor { pub fn new(pool: sqlx::PgPool) -> Self { RtssDbAccessor { pool } } + + pub async fn get_next_mqtt_client_id(&self) -> Result { + let seq_name = MqttClientIdSeq::Name.name(); + let next = sqlx::query_scalar(&format!("SELECT nextval('{}')", seq_name)) + .fetch_one(&self.pool) + .await?; + Ok(next) + } } pub async fn get_db_accessor(url: &str) -> RtssDbAccessor { let pool = sqlx::PgPool::connect(url).await.expect("连接数据库失败"); RtssDbAccessor::new(pool) } + +#[cfg(test)] +mod tests { + use super::*; + use rtss_log::tracing::{self, Level}; + use sqlx::PgPool; + + // You could also do `use foo_crate::MIGRATOR` and just refer to it as `MIGRATOR` here. + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_get_mqtt_client_id(pool: PgPool) -> Result<(), DbAccessError> { + rtss_log::Logging::default().with_level(Level::DEBUG).init(); + let accessor = crate::db_access::RtssDbAccessor::new(pool); + for _ in 0..10 { + let id = accessor.get_next_mqtt_client_id().await?; + tracing::info!("id = {}", id); + assert!(id > 0); + } + Ok(()) + } +} diff --git a/crates/rtss_db/src/model.rs b/crates/rtss_db/src/model.rs index 1f5d10e..14377de 100644 --- a/crates/rtss_db/src/model.rs +++ b/crates/rtss_db/src/model.rs @@ -7,6 +7,18 @@ use sqlx::types::{ use crate::common::TableColumn; +pub enum MqttClientIdSeq { + Name, +} + +impl MqttClientIdSeq { + pub fn name(&self) -> &str { + match self { + MqttClientIdSeq::Name => "rtss.mqtt_client_id_seq", + } + } +} + #[derive(Debug)] pub enum UserColumn { Table, @@ -101,7 +113,7 @@ pub struct ReleaseDataModel { /// 数据库表 rtss.release_data_version 列映射 #[derive(Debug)] -pub(crate) enum ReleaseDataVersionColumn { +pub enum ReleaseDataVersionColumn { Table, Id, ReleaseDataId, @@ -128,7 +140,7 @@ pub struct ReleaseDataVersionModel { /// 数据库表 rtss.feature 列映射 #[derive(Debug)] #[allow(dead_code)] -pub(crate) enum FeatureColumn { +pub enum FeatureColumn { Table, Id, FeatureType, @@ -160,7 +172,7 @@ pub struct FeatureModel { /// 数据库表 rtss.user_config 列映射 #[derive(Debug)] #[allow(dead_code)] -pub(crate) enum UserConfigColumn { +pub enum UserConfigColumn { Table, Id, UserId, diff --git a/crates/rtss_mqtt/Cargo.toml b/crates/rtss_mqtt/Cargo.toml new file mode 100644 index 0000000..5f640c5 --- /dev/null +++ b/crates/rtss_mqtt/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "rtss_mqtt" +version = "0.1.0" +edition = "2021" + +[dependencies] +rumqttc = { version = "0.24.0", features = ["url"] } +tokio = { workspace = true } +async-trait = { workspace = true } +bytes = { workspace = true } +lazy_static = { workspace = true } +thiserror = { workspace = true } + +rtss_db = { path = "../rtss_db" } +rtss_log = { path = "../rtss_log" } diff --git a/crates/rtss_mqtt/src/error.rs b/crates/rtss_mqtt/src/error.rs new file mode 100644 index 0000000..83b4c5f --- /dev/null +++ b/crates/rtss_mqtt/src/error.rs @@ -0,0 +1,14 @@ +use rumqttc::v5::ClientError; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum MqttClientError { + #[error("未知的Mqtt客户端错误")] + Unknown, + #[error("客户端已设置")] + AlreadySet, + #[error("rumqttc 错误: {0}")] + ClientError(#[from] ClientError), + #[error("全局客户端未设置")] + NoClient, +} diff --git a/crates/rtss_mqtt/src/lib.rs b/crates/rtss_mqtt/src/lib.rs new file mode 100644 index 0000000..d9f72d2 --- /dev/null +++ b/crates/rtss_mqtt/src/lib.rs @@ -0,0 +1,612 @@ +use std::{ + any::TypeId, + collections::HashMap, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, + }, + task::Waker, + time::Duration, +}; + +use bytes::Bytes; +use lazy_static::lazy_static; +use rtss_log::tracing::{debug, error, info}; +use rumqttc::{ + v5::{ + mqttbytes::{ + v5::{Packet, Publish, PublishProperties}, + QoS, + }, + AsyncClient, Event, EventLoop, MqttOptions, + }, + Outgoing, +}; +use tokio::{sync::oneshot, time::timeout}; + +mod error; +use error::MqttClientError; + +lazy_static! { +/// 全局静态MqttClient实例 +static ref MQTT_CLIENT: tokio::sync::Mutex> = tokio::sync::Mutex::new(None); +} + +/// 设置全局MqttClient实例 +pub async fn set_global_mqtt_client(client: MqttClient) -> Result<(), MqttClientError> { + let mut mqtt_client = MQTT_CLIENT.lock().await; + if mqtt_client.is_some() { + return Err(MqttClientError::AlreadySet); + } + *mqtt_client = Some(client); + Ok(()) +} + +pub async fn get_global_mqtt_client() -> Option { + let mqtt_client = MQTT_CLIENT.lock().await; + mqtt_client.clone() +} + +pub struct MqttClientOptions { + id: String, + options: MqttOptions, + request_timeout: Duration, +} + +impl MqttClientOptions { + pub fn new(id: &str, url: &str) -> Self { + Self { + id: id.to_string(), + options: MqttOptions::parse_url(format!("{}?client_id={}", url, id)) + .expect("解析mqtt url失败"), + request_timeout: Duration::from_secs(5), + } + } + + pub fn set_request_timeout(&mut self, timeout: Duration) -> &mut Self { + self.request_timeout = timeout; + self + } + + pub fn set_credentials(&mut self, username: &str, password: &str) -> &mut Self { + self.options.set_credentials(username, password); + self + } + + pub fn build(&mut self) -> MqttClient { + self.options.set_keep_alive(Duration::from_secs(10)); + let (client, eventloop) = AsyncClient::new(self.options.clone(), 10); + + let subscriptions = SubscribeHandlerMap::new(); + let loop_sub = subscriptions.clone(); + tokio::spawn(async move { + MqttClient::handle_connection_loop(eventloop, loop_sub).await; + }); + + MqttClient { + id: self.id.clone(), + request_timeout: self.request_timeout, + client, + request_id: Arc::new(AtomicU64::new(0)), + subscriptions, + } + } +} + +/// MQTT客户端 +/// id: 客户端ID,从数据库的id序列中获取 +/// 客户端具有的功能: +/// 1. 启动 +/// 2. 订阅 +/// 3. 发布 +/// 4. 实现类似http的请求相应功能 +/// 5. 断开连接 +#[derive(Clone)] +pub struct MqttClient { + id: String, + request_timeout: Duration, + client: AsyncClient, + request_id: Arc, + subscriptions: SubscribeHandlerMap, +} + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub struct HandlerId(TypeId); + +#[derive(Clone)] +pub struct SubscribeHandlerMap { + sub_handlers: Arc>>, +} + +impl SubscribeHandlerMap { + fn new() -> Self { + Self { + sub_handlers: Arc::new(Mutex::new(HashMap::new())), + } + } + + fn insert( + &self, + topic: &str, + handler_id: HandlerId, + handler: Arc, + ) -> HandlerId { + self.sub_handlers + .lock() + .unwrap() + .entry(topic.to_string()) + .or_insert_with(MessageHandlerMap::new) + .insert(handler_id, handler); + handler_id + } + + fn remove(&self, topic: &str, handler_id: HandlerId) { + if let Some(topic_handlers) = self.sub_handlers.lock().unwrap().get_mut(topic) { + topic_handlers.remove(handler_id); + } + } + + #[allow(dead_code)] + fn remove_all(&self, topic: &str) { + if let Some(topic_handlers) = self.sub_handlers.lock().unwrap().get_mut(topic) { + topic_handlers.remove_all(); + } + } + + fn get_handlers(&self, topic: &str) -> Option>> { + if let Some(topic_handlers) = self.sub_handlers.lock().unwrap().get(topic) { + Some(topic_handlers.values()) + } else { + None + } + } + + #[allow(dead_code)] + fn get_mut(&self, topic: &str) -> Option { + self.sub_handlers.lock().unwrap().get(topic).cloned() + } + + #[allow(dead_code)] + fn is_topic_empty(&self, topic: &str) -> bool { + if let Some(topic_handlers) = self.sub_handlers.lock().unwrap().get(topic) { + topic_handlers.is_empty() + } else { + true + } + } + + #[allow(dead_code)] + fn is_empty(&self) -> bool { + self.sub_handlers.lock().unwrap().is_empty() + } + + fn clear(&self) { + self.sub_handlers.lock().unwrap().clear(); + } +} + +#[derive(Clone)] +struct MessageHandlerMap { + handlers: Arc>>>, +} + +impl MessageHandlerMap { + fn new() -> Self { + Self { + handlers: Arc::new(Mutex::new(HashMap::new())), + } + } + + fn insert(&self, handler_id: HandlerId, handler: Arc) { + self.handlers.lock().unwrap().insert(handler_id, handler); + } + + /// 移除处理器,返回剩余处理器数量 + fn remove(&self, handler_id: HandlerId) -> Option> { + self.handlers.lock().unwrap().remove(&handler_id) + } + + #[allow(dead_code)] + fn remove_all(&self) { + self.handlers.lock().unwrap().clear(); + } + + fn values(&self) -> Vec> { + self.handlers.lock().unwrap().values().cloned().collect() + } + + #[allow(dead_code)] + fn is_empty(&self) -> bool { + self.handlers.lock().unwrap().is_empty() + } +} + +#[must_use = "this `SubscribeTopicHandler` implements `Drop`, which will unregister the handler"] +#[derive(Clone)] +pub struct SubscribeTopicHandler { + topic: String, + handler_id: HandlerId, + handler_map: SubscribeHandlerMap, +} + +impl SubscribeTopicHandler { + pub fn new(topic: &str, handler_id: HandlerId, handler_map: SubscribeHandlerMap) -> Self { + Self { + topic: topic.to_string(), + handler_id, + handler_map, + } + } + + pub fn unregister(&self) { + self.handler_map.remove(&self.topic, self.handler_id); + } +} + +/// 订阅消息处理器 +#[async_trait::async_trait] +pub trait MessageHandler: Send + Sync { + async fn handle(&self, publish: Publish); +} + +/// 为闭包实现消息处理器 +#[async_trait::async_trait] +impl MessageHandler for F +where + F: Fn(Publish) + Sync + Send, +{ + async fn handle(&self, publish: Publish) { + self(publish); + } +} + +impl MqttClient { + pub async fn close(&self) -> Result<(), MqttClientError> { + self.client.disconnect().await?; + // 清空订阅处理器 + self.subscriptions.clear(); + Ok(()) + } + + pub fn id(&self) -> &str { + &self.id + } + + pub async fn subscribe(&self, topic: &str, qos: QoS) -> Result<(), MqttClientError> { + self.client.subscribe(topic, qos).await?; + Ok(()) + } + + pub async fn unsubscribe(&self, topic: &str) -> Result<(), MqttClientError> { + self.client.unsubscribe(topic).await?; + Ok(()) + } + + pub fn register_topic_handler(&self, topic: &str, handler: H) -> SubscribeTopicHandler + where + H: MessageHandler + 'static, + { + let handler_id = HandlerId(TypeId::of::()); + self.subscriptions + .insert(topic, handler_id, Arc::new(handler)); + SubscribeTopicHandler::new(topic, handler_id, self.subscriptions.clone()) + } + + pub fn unregister_topic(&self, topic: &str) { + self.subscriptions.remove_all(topic); + } + + pub fn topic_handler_count(&self, topic: &str) -> usize { + if let Some(topic_handlers) = self.subscriptions.get_handlers(topic) { + topic_handlers.len() + } else { + 0 + } + } + + pub async fn publish( + &self, + topic: &str, + qos: QoS, + payload: Vec, + ) -> Result<(), MqttClientError> { + self.client.publish(topic, qos, false, payload).await?; + Ok(()) + } + + pub fn next_request_id(&self) -> u64 { + self.request_id.fetch_add(1, Ordering::Relaxed) + } + + /// 发送请求并等待响应 + pub async fn request( + &self, + topic: &str, + qos: QoS, + payload: Vec, + ) -> Result { + // 订阅响应主题 + let response_topic = format!("{}/{}/resp/{}", self.id, topic, self.next_request_id()); + self.subscribe(&response_topic, QoS::ExactlyOnce).await?; + // 创建请求future + let response_future = MqttResponseFuture::new(&response_topic, self.request_timeout); + // 注册响应处理器 + let response_handler = + self.register_topic_handler(&response_topic, response_future.clone()); + // 发布请求 + let property = PublishProperties { + response_topic: Some(response_topic.clone().into()), + ..Default::default() + }; + self.client + .publish_with_properties(topic, qos, false, payload, property) + .await?; + // 等待响应 + let resp = response_future.await; + // 注销响应处理器并取消订阅 + response_handler.unregister(); + self.unsubscribe(&response_topic).await?; + Ok(resp) + } + + async fn handle_connection_loop(mut eventloop: EventLoop, subscriptions: SubscribeHandlerMap) { + while let Ok(notification) = eventloop.poll().await { + match notification { + Event::Incoming(Packet::Publish(publish)) => { + debug!("Received message: {:?}", publish); + let topic: String = String::from_utf8_lossy(&publish.topic).to_string(); + + if let Some(topic_handlers) = subscriptions.get_handlers(&topic) { + for handler in topic_handlers { + let handler = handler.clone(); + let p = publish.clone(); + tokio::spawn(async move { + handler.handle(p).await; + }); + } + } + } + Event::Outgoing(Outgoing::Disconnect) => { + info!("Disconnected to the broker"); + break; + } + Event::Incoming(Packet::Disconnect(disconnect)) => { + info!("Disconnected from the broker: {:?}", disconnect); + break; + } + _ => { + debug!("Unhandled event: {:?}", notification); + } + } + } + } +} + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum MqttResponseState { + Waiting, + Received, + Timeout, +} + +/// MQTT请求响应 +#[derive(Clone, Debug)] +pub struct MqttResponse { + state: Arc>, + response: Arc>, +} + +impl MqttResponse { + pub fn new() -> Self { + Self { + state: Arc::new(Mutex::new(MqttResponseState::Waiting)), + response: Arc::new(Mutex::new(Bytes::new())), + } + } + + pub fn is_waiting(&self) -> bool { + *self.state.lock().unwrap() == MqttResponseState::Waiting + } + + pub fn is_received(&self) -> bool { + *self.state.lock().unwrap() == MqttResponseState::Received + } + + pub fn is_timeout(&self) -> bool { + *self.state.lock().unwrap() == MqttResponseState::Timeout + } + + pub fn set_timeout(&self) { + *self.state.lock().unwrap() = MqttResponseState::Timeout; + } + + pub fn set(&self, response: Bytes) { + *self.state.lock().unwrap() = MqttResponseState::Received; + *self.response.lock().unwrap() = response; + } + + pub fn get(&self) -> Bytes { + self.response.lock().unwrap().clone() + } +} + +/// MQTT响应Future +#[derive(Clone)] +pub struct MqttResponseFuture { + pub start_time: std::time::Instant, + timeout: Duration, + tx: Arc>>>, + waker: Arc>>, + response_topic: String, + response: MqttResponse, +} + +impl MqttResponseFuture { + pub fn new(response_topic: &str, timeout: Duration) -> Self { + let (tx, rx) = oneshot::channel(); + let r = Self { + start_time: std::time::Instant::now(), + timeout, + tx: Arc::new(Mutex::new(Some(tx))), + waker: Arc::new(Mutex::new(None)), + response_topic: response_topic.to_string(), + response: MqttResponse::new(), + }; + // 启动超时检查 + r.start_timeout_monitor(rx); + + r + } + + /// 启动超时监控任务逻辑 + fn start_timeout_monitor(&self, rx: oneshot::Receiver<()>) { + let response = self.response.clone(); + let response_topic = self.response_topic.clone(); + let duration = self.timeout.clone(); + let waker = self.waker.clone(); + tokio::spawn(async move { + if let Err(_) = timeout(duration, rx).await { + error!("Mqtt response timeout: {:?}", response_topic); + response.set_timeout(); + if let Some(waker) = waker.lock().unwrap().take() { + waker.wake(); + } + } + }); + } +} + +#[async_trait::async_trait] +impl MessageHandler for MqttResponseFuture { + async fn handle(&self, publish: Publish) { + if publish.topic == self.response_topic { + self.response.set(publish.payload); + if let Some(tx) = self.tx.lock().unwrap().take() { + tx.send(()) + .expect("Send Mqtt response timeout signal failed"); + } + if let Some(waker) = self.waker.lock().unwrap().take() { + waker.wake(); + } + } + } +} + +impl std::future::Future for MqttResponseFuture { + type Output = MqttResponse; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + if self.response.is_waiting() { + debug!("Response future poll waiting..."); + self.waker.lock().unwrap().replace(cx.waker().clone()); + std::task::Poll::Pending + } else { + debug!("Response future poll ready: {:?}", self.response.get()); + std::task::Poll::Ready(self.response.clone()) + } + } +} + +pub fn get_publish_response_topic(publish: Option) -> Option { + publish.map(|p| p.response_topic.clone()).flatten() +} + +#[cfg(test)] +mod tests { + use super::*; + use rtss_log::tracing::{info, Level}; + + use tokio::time::{sleep, Duration}; + + #[tokio::test] + async fn test_subscribe_and_publish() { + rtss_log::Logging::default().with_level(Level::DEBUG).init(); + let client = MqttClientOptions::new("rtss_test1", "tcp://localhost:1883") + .set_credentials("rtss_simulation", "Joylink@0503") + .build(); + + client + .subscribe("test/topic", QoS::AtMostOnce) + .await + .unwrap(); + let handler1 = client.register_topic_handler("test/topic", |publish: Publish| { + info!( + "Handler 1 received: topic={}, payload={:?}", + String::from_utf8_lossy(&publish.topic), + String::from_utf8_lossy(&publish.payload) + ); + }); + let h2 = client.register_topic_handler("test/topic", |publish: Publish| { + info!( + "Handler 2 received: topic={}, payload={:?}", + String::from_utf8_lossy(&publish.topic), + String::from_utf8_lossy(&publish.payload) + ); + }); + + assert_eq!(client.topic_handler_count("test/topic"), 2); + client + .publish("test/topic", QoS::AtMostOnce, b"Hello, MQTT!".to_vec()) + .await + .unwrap(); + + // Wait for a moment to allow handlers to process the message + sleep(Duration::from_millis(200)).await; + + // Test remove_handler + client.unsubscribe("test/topic").await.unwrap(); + handler1.unregister(); + assert_eq!(client.topic_handler_count("test/topic"), 1); + + h2.unregister(); + assert_eq!(client.topic_handler_count("test/topic"), 0); + + // Test unsubscribe + client.close().await.unwrap(); + } + + #[tokio::test] + async fn test_request() { + rtss_log::Logging::default().with_level(Level::DEBUG).init(); + let client = MqttClientOptions::new("rtss_test1", "tcp://localhost:1883") + .set_credentials("rtss_simulation", "Joylink@0503") + .build(); + set_global_mqtt_client(client.clone()).await.unwrap(); + + if let Some(c) = get_global_mqtt_client().await { + c.subscribe("test/request", QoS::AtMostOnce).await.unwrap(); + let handler = |p: Publish| { + info!( + "Request handler received: topic={}, payload={:?}", + String::from_utf8_lossy(&p.topic), + String::from_utf8_lossy(&p.payload) + ); + let response = Bytes::from("Hello, response!"); + let resp_topic = get_publish_response_topic(p.properties.clone()); + if let Some(r_topic) = resp_topic { + tokio::spawn(async move { + if let Some(c) = get_global_mqtt_client().await { + c.publish(&r_topic, QoS::AtMostOnce, response.to_vec()) + .await + .unwrap(); + } + }); + } + }; + let _ = c.register_topic_handler("test/request", handler); + } + + if let Some(c) = get_global_mqtt_client().await { + let response = c + .request("test/request", QoS::AtMostOnce, b"Hello, request!".to_vec()) + .await + .unwrap(); + info!("Request response: {:?}", response); + } + + client.close().await.unwrap(); + } +} diff --git a/migrations/20240830095636_init.up.sql b/migrations/20240830095636_init.up.sql index 3e5330c..0191bd9 100644 --- a/migrations/20240830095636_init.up.sql +++ b/migrations/20240830095636_init.up.sql @@ -1,6 +1,9 @@ -- 初始化数据库SCHEMA(所有轨道交通信号系统仿真的表、类型等都在rtss SCHEMA下) CREATE SCHEMA rtss; +-- 创建mqtt客户端id序列 +CREATE SEQUENCE rtss.mqtt_client_id_seq; + -- 创建用户表 CREATE TABLE rtss.user (