From f70e08e5e89528c4e044959607df7b4e9edee5fa Mon Sep 17 00:00:00 2001 From: soul-walker <31162815+soul-walker@users.noreply.github.com> Date: Wed, 25 Sep 2024 18:35:45 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=94=A8=E6=88=B7=E8=A1=A8?= =?UTF-8?q?=20=E6=B7=BB=E5=8A=A0=E7=94=A8=E6=88=B7=E5=90=8C=E6=AD=A5?= =?UTF-8?q?=EF=BC=8C=E6=9F=A5=E8=AF=A2=E7=AD=89=E6=8E=A5=E5=8F=A3=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=8E=B7=E5=8F=96jwt=E6=8E=A5=E5=8F=A3=20?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E8=8D=89=E7=A8=BF=E6=95=B0=E6=8D=AE=E3=80=81?= =?UTF-8?q?=E5=8F=91=E5=B8=83=E6=95=B0=E6=8D=AE=E7=94=A8=E6=88=B7name?= =?UTF-8?q?=E5=85=B3=E8=81=94=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 100 +++++ conf/default.toml | 1 + conf/dev.toml | 2 +- crates/rtss_api/Cargo.toml | 2 +- crates/rtss_api/src/apis/draft_data.rs | 45 +-- crates/rtss_api/src/apis/mod.rs | 7 +- crates/rtss_api/src/apis/release_data.rs | 11 +- crates/rtss_api/src/apis/user.rs | 162 ++++++++ crates/rtss_api/src/jwt_auth.rs | 82 ---- crates/rtss_api/src/lib.rs | 1 + crates/rtss_api/src/loader/mod.rs | 10 + crates/rtss_api/src/server.rs | 21 +- crates/rtss_api/src/user_auth/jwt_auth.rs | 87 +++++ crates/rtss_api/src/user_auth/mod.rs | 87 ++++- crates/rtss_db/src/db_access/draft_data.rs | 31 +- crates/rtss_db/src/db_access/mod.rs | 2 + crates/rtss_db/src/db_access/release_data.rs | 25 +- crates/rtss_db/src/db_access/user.rs | 380 +++++++++++++++++++ crates/rtss_db/src/model.rs | 43 +++ migrations/20240830095636_init.up.sql | 25 +- src/app_config.rs | 1 + src/cmd.rs | 1 + 22 files changed, 985 insertions(+), 141 deletions(-) create mode 100644 crates/rtss_api/src/apis/user.rs delete mode 100644 crates/rtss_api/src/jwt_auth.rs create mode 100644 crates/rtss_api/src/loader/mod.rs create mode 100644 crates/rtss_api/src/user_auth/jwt_auth.rs create mode 100644 crates/rtss_db/src/db_access/user.rs diff --git a/Cargo.lock b/Cargo.lock index 77d56be..831b6e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -917,6 +917,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "digest" version = "0.10.7" @@ -1523,6 +1532,21 @@ dependencies = [ "serde", ] +[[package]] +name = "jsonwebtoken" +version = "9.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ae10193d25051e74945f1ea2d0b42e03cc3b890f7e4cc5faa44997d808193f" +dependencies = [ + "base64 0.21.7", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1714,6 +1738,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -1731,6 +1765,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -1833,6 +1873,16 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd" +[[package]] +name = "pem" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e459365e590736a54c3fa561947c84837534b8e9af6fc5bf781307e82658fae" +dependencies = [ + "base64 0.22.1", + "serde", +] + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -1962,6 +2012,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -2322,6 +2378,7 @@ dependencies = [ "base64 0.22.1", "bevy_ecs", "chrono", + "jsonwebtoken", "reqwest", "rtss_db", "rtss_dto", @@ -2629,6 +2686,18 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time", +] + [[package]] name = "slab" version = "0.4.9" @@ -3031,6 +3100,37 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tiny-keccak" version = "2.0.2" diff --git a/conf/default.toml b/conf/default.toml index 58b6188..1a2dd81 100644 --- a/conf/default.toml +++ b/conf/default.toml @@ -11,3 +11,4 @@ base_url = "http://localhost:8080" login_url = "/api/login" logout_url = "/api/login/logout" user_info_url = "/api/login/getUserInfo" +sync_user_url = "/api/userinfo/list/all" diff --git a/conf/dev.toml b/conf/dev.toml index 5490cc1..839879b 100644 --- a/conf/dev.toml +++ b/conf/dev.toml @@ -2,4 +2,4 @@ url = "postgresql://joylink:Joylink@0503@localhost:5432/joylink" [sso] -base_url = "https://joylink.club/jlcloud" +base_url = "http://192.168.33.233/rtss-server" diff --git a/crates/rtss_api/Cargo.toml b/crates/rtss_api/Cargo.toml index e66797d..bbc8cef 100644 --- a/crates/rtss_api/Cargo.toml +++ b/crates/rtss_api/Cargo.toml @@ -11,7 +11,7 @@ serde_json = { workspace = true } chrono = { version = "0.4.38", features = ["serde"] } axum = "0.7.5" axum-extra = { version = "0.9.3", features = ["typed-header"] } -# jsonwebtoken = "9.3.0" +jsonwebtoken = "9.3.0" tower-http = { version = "0.6.0", features = ["cors"] } async-graphql = { version = "7.0.7", features = ["chrono", "dataloader"] } async-graphql-axum = "7.0.6" diff --git a/crates/rtss_api/src/apis/draft_data.rs b/crates/rtss_api/src/apis/draft_data.rs index a5f96d4..e811f96 100644 --- a/crates/rtss_api/src/apis/draft_data.rs +++ b/crates/rtss_api/src/apis/draft_data.rs @@ -9,10 +9,11 @@ use rtss_dto::common::DataType; use serde_json::Value; use crate::apis::{PageDto, PageQueryDto}; +use crate::loader::RtssDbLoader; use super::common::{DataOptions, IscsDataOptions}; use super::release_data::ReleaseDataId; -use crate::RtssDbLoader; +use super::user::UserId; use crate::user_auth::{Role, RoleGuard, Token, UserAuthCache}; @@ -30,7 +31,7 @@ impl DraftDataQuery { &self, ctx: &Context<'ctx>, paging: PageQueryDto, - query: DraftDataFilterDto, + query: DraftDataFilterDto, ) -> async_graphql::Result> { let db_accessor = ctx.data::()?; let paging_result = db_accessor @@ -64,7 +65,7 @@ impl DraftDataQuery { &self, ctx: &Context<'ctx>, paging: PageQueryDto, - mut query: SharedDraftDataFilterDto, + mut query: DraftDataFilterDto, ) -> async_graphql::Result> { let db_accessor = ctx.data::()?; query.data_type = Some(DataType::Iscs); @@ -258,9 +259,11 @@ impl From> for rtss_db::DraftDataQuery /// 共享的草稿数据查询条件 #[derive(Debug, InputObject)] +#[graphql(concrete(name = "DraftDataFilterDto", params(Value)))] #[graphql(concrete(name = "SharedDraftIscsDataFilterDto", params(IscsDataOptions)))] -pub struct SharedDraftDataFilterDto { +pub struct DraftDataFilterDto { + #[graphql(skip)] pub user_id: Option, pub name: Option, /// 数据类型,在某个具体类型查询时不传,传了也不生效 @@ -268,8 +271,8 @@ pub struct SharedDraftDataFilterDto { pub options: Option, } -impl From> for rtss_db::DraftDataQuery { - fn from(value: SharedDraftDataFilterDto) -> Self { +impl From> for rtss_db::DraftDataQuery { + fn from(value: DraftDataFilterDto) -> Self { Self { user_id: value.user_id, name: value.name, @@ -280,28 +283,6 @@ impl From> for rtss_db::DraftDataQue } } -/// 草稿数据查询条件 -#[derive(Debug, InputObject)] -pub struct DraftDataFilterDto { - pub user_id: Option, - pub name: Option, - pub data_type: Option, - pub options: Option, - pub is_shared: Option, -} - -impl From for rtss_db::DraftDataQuery { - fn from(value: DraftDataFilterDto) -> Self { - Self { - user_id: value.user_id, - name: value.name, - data_type: value.data_type, - is_shared: value.is_shared, - options: value.options, - } - } -} - #[derive(Debug, SimpleObject)] #[graphql(complex)] pub struct DraftDataDto { @@ -320,6 +301,7 @@ pub struct DraftDataDto { #[ComplexObject] impl DraftDataDto { + /// 获取默认发布数据name async fn default_release_data_name( &self, ctx: &Context<'_>, @@ -332,6 +314,13 @@ impl DraftDataDto { Ok(None) } } + + /// 获取用户name + async fn user_name(&self, ctx: &Context<'_>) -> async_graphql::Result> { + let loader = ctx.data_unchecked::>(); + let name = loader.load_one(UserId::new(self.user_id)).await?; + Ok(name) + } } impl From for DraftDataDto { diff --git a/crates/rtss_api/src/apis/mod.rs b/crates/rtss_api/src/apis/mod.rs index 154b39b..96f68be 100644 --- a/crates/rtss_api/src/apis/mod.rs +++ b/crates/rtss_api/src/apis/mod.rs @@ -5,17 +5,19 @@ use release_data::{ReleaseDataMutation, ReleaseDataQuery}; mod simulation_definition; mod sys_info; use simulation_definition::*; +use user::{UserMutation, UserQuery}; mod common; mod draft_data; mod release_data; mod simulation; +mod user; #[derive(Default, MergedObject)] -pub struct Query(DraftDataQuery, ReleaseDataQuery); +pub struct Query(UserQuery, DraftDataQuery, ReleaseDataQuery); #[derive(Default, MergedObject)] -pub struct Mutation(DraftDataMutation, ReleaseDataMutation); +pub struct Mutation(UserMutation, DraftDataMutation, ReleaseDataMutation); #[derive(Enum, Copy, Clone, Default, Eq, PartialEq, Debug)] #[graphql(remote = "rtss_db::common::SortOrder")] @@ -43,6 +45,7 @@ impl From for rtss_db::common::PageQuery { } #[derive(Debug, SimpleObject)] +#[graphql(concrete(name = "UserPageDto", params(user::UserDto)))] #[graphql(concrete(name = "DraftDataPageDto", params(draft_data::DraftDataDto)))] #[graphql(concrete(name = "DraftIscsDataPageDto", params(draft_data::DraftIscsDataDto)))] #[graphql(concrete(name = "ReleaseDataPageDto", params(release_data::ReleaseDataDto)))] diff --git a/crates/rtss_api/src/apis/release_data.rs b/crates/rtss_api/src/apis/release_data.rs index 7c8c458..e8da824 100644 --- a/crates/rtss_api/src/apis/release_data.rs +++ b/crates/rtss_api/src/apis/release_data.rs @@ -12,9 +12,10 @@ use rtss_dto::common::DataType; use serde_json::Value; use crate::apis::draft_data::DraftDataDto; -use crate::RtssDbLoader; +use crate::loader::RtssDbLoader; use super::common::{DataOptions, IscsDataOptions}; +use super::user::UserId; use super::{PageDto, PageQueryDto}; use crate::user_auth::{Role, RoleGuard, Token, UserAuthCache}; @@ -225,6 +226,7 @@ impl ReleaseDataMutation { #[graphql(concrete(name = "ReleaseDataFilterDto", params(Value)))] #[graphql(concrete(name = "ReleaseIscsDataFilterDto", params(IscsDataOptions)))] pub struct ReleaseTypedDataFilterDto { + #[graphql(skip)] pub user_id: Option, pub name: Option, /// 数据类型,在某个具体类型查询时不传,传了也不生效 @@ -272,6 +274,13 @@ impl ReleaseDataDto { Ok(None) } } + + /// 获取用户name + async fn user_name(&self, ctx: &Context<'_>) -> async_graphql::Result> { + let loader = ctx.data_unchecked::>(); + let name = loader.load_one(UserId::new(self.user_id)).await?; + Ok(name) + } } #[derive(Clone, Copy, Hash, PartialEq, Eq)] diff --git a/crates/rtss_api/src/apis/user.rs b/crates/rtss_api/src/apis/user.rs new file mode 100644 index 0000000..4e6ee14 --- /dev/null +++ b/crates/rtss_api/src/apis/user.rs @@ -0,0 +1,162 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_graphql::{dataloader::Loader, Context, InputObject, Object, SimpleObject}; +use chrono::NaiveDateTime; +use rtss_db::{DbAccessError, RtssDbAccessor, UserAccessor}; + +use crate::{ + loader::RtssDbLoader, + user_auth::{build_jwt, Claims, Role, RoleGuard, Token, UserAuthCache, UserInfoDto}, + UserAuthClient, +}; + +use super::{PageDto, PageQueryDto}; + +#[derive(Default)] +pub struct UserQuery; + +#[Object] +impl UserQuery { + /// 获取用户信息 + #[graphql(guard = "RoleGuard::new(Role::User)")] + async fn login_user_info(&self, ctx: &Context<'_>) -> async_graphql::Result { + let user = ctx + .data::()? + .query_user(&ctx.data::()?.0) + .await?; + Ok(user.into()) + } + + /// 获取jwt令牌(mqtt验证) + #[graphql(guard = "RoleGuard::new(Role::User)")] + async fn get_jwt(&self, ctx: &Context<'_>) -> async_graphql::Result { + let user = ctx + .data::()? + .query_user(&ctx.data::()?.0) + .await?; + let jwt = build_jwt(Claims::new(user.id_i32()))?; + Ok(jwt.0) + } + + /// 分页查询用户(系统管理) + #[graphql(guard = "RoleGuard::new(Role::Admin)")] + async fn user_paging( + &self, + ctx: &Context<'_>, + page: PageQueryDto, + query: UserQueryDto, + ) -> async_graphql::Result> { + let dba = ctx.data::()?; + let paging = dba.query_user_page(page.into(), query.into()).await?; + Ok(paging.into()) + } +} + +#[derive(Default)] +pub struct UserMutation; + +#[Object] +impl UserMutation { + /// 同步用户 + #[graphql(guard = "RoleGuard::new(Role::Admin)")] + async fn sync_user(&self, ctx: &Context<'_>) -> async_graphql::Result { + let http_client = ctx.data::()?; + let users = http_client.query_all_users(ctx.data::()?).await?; + let dba = ctx.data::()?; + dba.sync_user( + users + .into_iter() + .map(|u| u.into()) + .collect::>() + .as_slice(), + ) + .await?; + Ok(true) + } +} + +#[derive(Debug, InputObject)] +pub struct UserQueryDto { + pub id: Option, + pub name: Option, + pub email: Option, + pub mobile: Option, + pub roles: Option>, +} + +impl From for rtss_db::UserPageFilter { + fn from(value: UserQueryDto) -> Self { + Self { + id: value.id, + name: value.name, + email: value.email, + mobile: value.mobile, + roles: value.roles.map(|r| serde_json::to_value(r).unwrap()), + } + } +} + +#[derive(Debug, SimpleObject)] +pub struct UserDto { + pub id: i32, + pub name: String, + pub mobile: Option, + pub email: Option, + pub roles: Vec, + pub created_at: NaiveDateTime, + pub updated_at: NaiveDateTime, +} + +impl From for UserDto { + fn from(value: UserInfoDto) -> Self { + Self { + id: value.id_i32(), + name: value.name(), + mobile: value.mobile.clone(), + email: value.email.clone(), + roles: value.roles(), + created_at: value.created_at().naive_local(), + updated_at: value.updated_at().naive_local(), + } + } +} + +impl From for UserDto { + fn from(value: rtss_db::model::UserModel) -> Self { + Self { + id: value.id, + name: value.username, + mobile: value.mobile, + email: value.email, + roles: serde_json::from_value(value.roles).unwrap(), + created_at: value.created_at.naive_local(), + updated_at: value.updated_at.naive_local(), + } + } +} + +#[derive(Clone, Copy, Hash, PartialEq, Eq)] +pub struct UserId { + pub id: i32, +} + +impl UserId { + pub fn new(id: i32) -> Self { + Self { id } + } +} + +impl Loader for RtssDbLoader { + type Value = String; + type Error = Arc; + + async fn load(&self, keys: &[UserId]) -> Result, Self::Error> { + let ids: Vec = keys.iter().map(|k| k.id).collect(); + let rows = self.db_accessor.query_user_name(ids.as_slice()).await?; + let map: HashMap = rows + .into_iter() + .map(|row| (UserId::new(row.0), row.1)) + .collect(); + Ok(map) + } +} diff --git a/crates/rtss_api/src/jwt_auth.rs b/crates/rtss_api/src/jwt_auth.rs deleted file mode 100644 index d667ef0..0000000 --- a/crates/rtss_api/src/jwt_auth.rs +++ /dev/null @@ -1,82 +0,0 @@ -// use std::sync::LazyLock; - -// use async_graphql::Result; -// use axum::http::HeaderMap; -// use jsonwebtoken::{decode, DecodingKey, Validation}; -// use rtss_log::tracing::error; -// use serde::{Deserialize, Serialize}; - -// static KEYS: LazyLock = LazyLock::new(|| { -// // let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); -// let secret = "joylink".to_string(); -// Keys::new(secret.as_bytes()) -// }); - -// struct Keys { -// // encoding: EncodingKey, -// decoding: DecodingKey, -// } - -// impl Keys { -// pub fn new(secret: &[u8]) -> Self { -// Self { -// // encoding: EncodingKey::from_secret(secret), -// decoding: DecodingKey::from_secret(secret), -// } -// } -// } - -// #[derive(Debug)] -// pub enum AuthError { -// InvalidToken, -// } - -// pub(crate) fn get_token_from_headers(headers: HeaderMap) -> Result, AuthError> { -// let option_token = headers.get("Token"); -// if let Some(token) = option_token { -// let token_data = decode::( -// token.to_str().unwrap(), -// &KEYS.decoding, -// &Validation::default(), -// ) -// .map_err(|err| { -// error!("Error decoding token: {:?}", err); -// AuthError::InvalidToken -// })?; -// Ok(Some(token_data.claims)) -// } else { -// Ok(None) -// } -// } - -// #[derive(Debug, Serialize, Deserialize)] -// pub struct Claims { -// pub id: u32, -// pub sub: String, -// } - -// #[cfg(test)] -// mod tests { - -// use super::*; - -// #[test] -// fn test_get_token_from_headers() { -// rtss_log::Logging::default().init(); -// let mut headers: HeaderMap = HeaderMap::new(); -// headers.insert("Token", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MjQ2NzAyMjcsImlkIjo2LCJvcmlnX2lhdCI6MTcyNDIzODIyNywic3ViIjoiNiJ9.sSfjdW7d3OqOE6G1p47c4dcCan4evRGoNjGPUyVfWLk".parse().unwrap()); -// let result = get_token_from_headers(headers); -// match result { -// Ok(Some(claims)) => { -// assert_eq!(claims.id, 6); -// assert_eq!(claims.sub, "6"); -// } -// Ok(None) => { -// panic!("Expected Some(claims), got None"); -// } -// Err(e) => { -// panic!("Error: {:?}", e); -// } -// } -// } -// } diff --git a/crates/rtss_api/src/lib.rs b/crates/rtss_api/src/lib.rs index 9f9be01..37142f1 100644 --- a/crates/rtss_api/src/lib.rs +++ b/crates/rtss_api/src/lib.rs @@ -1,5 +1,6 @@ // mod jwt_auth; mod apis; +mod loader; mod server; mod user_auth; diff --git a/crates/rtss_api/src/loader/mod.rs b/crates/rtss_api/src/loader/mod.rs new file mode 100644 index 0000000..693c098 --- /dev/null +++ b/crates/rtss_api/src/loader/mod.rs @@ -0,0 +1,10 @@ +/// 数据库加载器 +pub struct RtssDbLoader { + pub(crate) db_accessor: rtss_db::RtssDbAccessor, +} + +impl RtssDbLoader { + pub fn new(db_accessor: rtss_db::RtssDbAccessor) -> Self { + Self { db_accessor } + } +} diff --git a/crates/rtss_api/src/server.rs b/crates/rtss_api/src/server.rs index f861e6b..dbf7e13 100644 --- a/crates/rtss_api/src/server.rs +++ b/crates/rtss_api/src/server.rs @@ -15,6 +15,7 @@ use tokio::net::TcpListener; use tower_http::cors::CorsLayer; use crate::apis::{Mutation, Query}; +use crate::loader::RtssDbLoader; use crate::user_auth; pub use crate::user_auth::UserAuthClient; @@ -85,27 +86,17 @@ async fn graphiql() -> impl IntoResponse { Html(playground_source(GraphQLPlaygroundConfig::new("/"))) } -pub struct RtssDbLoader { - pub(crate) db_accessor: rtss_db::RtssDbAccessor, -} - -impl RtssDbLoader { - pub fn new(db_accessor: rtss_db::RtssDbAccessor) -> Self { - Self { db_accessor } - } -} - pub type RtssAppSchema = Schema; pub async fn new_schema(config: ServerConfig) -> RtssAppSchema { - let user_info_cache = crate::user_auth::UserAuthCache::new( - config - .user_auth_client - .expect("user auth client not configured"), - ); + let client = config + .user_auth_client + .expect("user auth client not configured"); + let user_info_cache = crate::user_auth::UserAuthCache::new(client.clone()); let dba = rtss_db::get_db_accessor(&config.database_url).await; let loader = RtssDbLoader::new(dba.clone()); Schema::build(Query::default(), Mutation::default(), EmptySubscription) + .data(client) .data(user_info_cache) .data(dba) .data(DataLoader::new(loader, tokio::spawn)) diff --git a/crates/rtss_api/src/user_auth/jwt_auth.rs b/crates/rtss_api/src/user_auth/jwt_auth.rs new file mode 100644 index 0000000..b0d144c --- /dev/null +++ b/crates/rtss_api/src/user_auth/jwt_auth.rs @@ -0,0 +1,87 @@ +use std::sync::LazyLock; + +use async_graphql::Result; +use jsonwebtoken::{decode, DecodingKey, EncodingKey, Validation}; +use serde::{Deserialize, Serialize}; + +static KEYS: LazyLock = LazyLock::new(|| { + // let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); + let secret = "joylink".to_string(); + Keys::new(secret.as_bytes()) +}); + +struct Keys { + encoding: EncodingKey, + decoding: DecodingKey, +} + +impl Keys { + pub fn new(secret: &[u8]) -> Self { + Self { + encoding: EncodingKey::from_secret(secret), + decoding: DecodingKey::from_secret(secret), + } + } +} + +#[derive(Debug)] +pub struct Jwt(pub String); + +#[derive(Debug, Serialize, Deserialize)] +pub struct Claims { + pub id: i32, + exp: usize, // 过期时间,单位秒 +} + +pub fn get_current_timestamp() -> u64 { + let start = std::time::SystemTime::now(); + start + .duration_since(std::time::UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() +} + +impl Claims { + pub fn new(id: i32) -> Self { + Self { + id, + exp: get_current_timestamp() as usize + 3600 * 24 * 7, // 7天 + } + } +} + +/// 构建jwt +pub fn build_jwt(claims: Claims) -> Result { + let token = jsonwebtoken::encode(&jsonwebtoken::Header::default(), &claims, &KEYS.encoding)?; + Ok(Jwt(token)) +} + +/// 解析jwt +#[allow(dead_code)] +pub fn decode_jwt(token: &str) -> Result { + let data = decode::(token, &KEYS.decoding, &Validation::default())?; + Ok(data.claims) +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_jwt() { + rtss_log::Logging::default().init(); + let claim = Claims::new(5); + let jwt = build_jwt(claim).unwrap(); + println!("jwt: {}", jwt.0); + let result = decode_jwt(&jwt.0); + match result { + Ok(claims) => { + assert_eq!(claims.id, 5); + } + Err(e) => { + panic!("Error: {:?}", e); + } + } + } +} diff --git a/crates/rtss_api/src/user_auth/mod.rs b/crates/rtss_api/src/user_auth/mod.rs index 4e66601..cf14a50 100644 --- a/crates/rtss_api/src/user_auth/mod.rs +++ b/crates/rtss_api/src/user_auth/mod.rs @@ -3,12 +3,16 @@ use std::{ sync::{Arc, Mutex}, }; -use async_graphql::Guard; +use async_graphql::{Enum, Guard}; use axum::http::HeaderMap; +use chrono::{DateTime, Local}; use rtss_log::tracing::error; use serde::{Deserialize, Serialize}; -#[derive(Eq, PartialEq, Clone, Copy, Hash)] +mod jwt_auth; +pub use jwt_auth::*; + +#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash, Enum, Serialize, Deserialize)] pub enum Role { Admin, User, @@ -71,6 +75,7 @@ pub struct UserAuthClient { pub login_url: String, pub logout_url: String, pub user_info_url: String, + pub sync_user_url: String, } impl UserAuthClient { @@ -120,6 +125,21 @@ impl UserAuthClient { let user_info = serde_json::from_value(common.data.unwrap())?; Ok(user_info) } + + pub async fn query_all_users(&self, token: &Token) -> anyhow::Result> { + let url = format!("{}{}", self.base_url, self.sync_user_url); + let response = reqwest::Client::new() + .get(&url) + .header("X-Token", &token.0) + .send() + .await?; + let common = response.json::().await?; + if common.code != 200 { + return Err(anyhow::anyhow!(common.message)); + } + let user_info_list = serde_json::from_value(common.data.unwrap())?; + Ok(user_info_list) + } } #[derive(Debug, Serialize, Deserialize)] @@ -157,6 +177,12 @@ pub struct UserInfoDto { pub name: Option, pub nickname: Option, pub roles: Vec, + pub mobile: Option, + pub email: Option, + #[serde(rename = "createTime")] + pub create_time: String, + #[serde(rename = "updateTime")] + pub update_time: Option, } impl UserInfoDto { pub fn id_i32(&self) -> i32 { @@ -165,6 +191,13 @@ impl UserInfoDto { .expect("parse UserInfoDto.id to i32 failed") } + pub fn name(&self) -> String { + self.name + .clone() + .or(self.nickname.clone()) + .unwrap_or_default() + } + pub fn roles(&self) -> Vec { let mut unique_roles = HashSet::new(); for role in &self.roles { @@ -180,6 +213,36 @@ impl UserInfoDto { } unique_roles.into_iter().collect() } + + pub fn created_at(&self) -> DateTime { + parse_to_date_time(&self.create_time) + } + + pub fn updated_at(&self) -> DateTime { + parse_to_date_time(self.update_time.as_deref().unwrap_or(&self.create_time)) + } +} + +fn parse_to_date_time(s: &str) -> chrono::DateTime { + chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S") + .expect("parse date_time failed") + .and_local_timezone(Local) + .unwrap() +} + +impl From for rtss_db::SyncUserInfo { + fn from(user_info: UserInfoDto) -> Self { + Self { + id: user_info.id_i32(), + name: user_info.name().replace("'", ""), // 需要处理name中带“'”字符的情况 + password: "".to_string(), // 暂时先不同步 + email: user_info.email.clone(), + mobile: user_info.mobile.clone(), + roles: serde_json::to_value(user_info.roles()).unwrap(), + created_at: parse_to_date_time(&user_info.create_time), + updated_at: user_info.update_time.map(|s| parse_to_date_time(&s)), + } + } } impl UserAuthCache { @@ -219,6 +282,7 @@ impl UserAuthCache { #[cfg(test)] mod tests { use anyhow::Ok; + use rtss_log::tracing::Level; use super::*; @@ -236,19 +300,32 @@ mod tests { assert_eq!(login_info.account, "17791995809"); } + #[test] + fn test_chrono_datetime_parse() { + let time_str = "2021-08-31 10:00:00"; + let dt = parse_to_date_time(time_str); + println!("{:?}", dt); + } + #[tokio::test] async fn test_user_auth_cache() -> anyhow::Result<()> { rtss_log::Logging::default().with_level(Level::DEBUG).init(); - let cache = UserAuthCache::new(UserAuthClient { - base_url: "https://joylink.club/jlcloud".to_string(), + let client = UserAuthClient { + base_url: "http://192.168.33.233/rtss-server".to_string(), login_url: "/api/login".to_string(), logout_url: "/api/login/logout".to_string(), user_info_url: "/api/login/getUserInfo".to_string(), - }); + sync_user_url: "/api/userinfo/list/all".to_string(), + }; + let cache = UserAuthCache::new(client.clone()); let token = cache.client.login(LoginInfo::default()).await?; let user = cache.query_user(&token).await?; println!("token: {}, {:?}", token, user); assert_eq!(cache.len(), 1); + + let user_list = client.query_all_users(&Token(token)).await?; + println!("{:?}", user_list); + Ok(()) } } diff --git a/crates/rtss_db/src/db_access/draft_data.rs b/crates/rtss_db/src/db_access/draft_data.rs index c060877..fc07551 100644 --- a/crates/rtss_db/src/db_access/draft_data.rs +++ b/crates/rtss_db/src/db_access/draft_data.rs @@ -426,22 +426,47 @@ impl DraftDataAccessor for RtssDbAccessor { #[cfg(test)] mod tests { + use crate::{SyncUserInfo, UserAccessor}; + use super::*; use rtss_dto::common::IscsStyle; use rtss_log::tracing::Level; use serde::{Deserialize, Serialize}; - use sqlx::PgPool; + use sqlx::{types::chrono::Local, PgPool}; #[derive(Debug, Serialize, Deserialize)] pub struct IscsDataOptions { pub style: IscsStyle, } + #[derive(Debug, Clone, Serialize, Deserialize)] + enum Role { + User, + Admin, + } + // You could also do `use foo_crate::MIGRATOR` and just refer to it as `MIGRATOR` here. #[sqlx::test(migrator = "crate::MIGRATOR")] async fn basic_use_test(pool: PgPool) -> Result<(), DbAccessError> { rtss_log::Logging::default().with_level(Level::DEBUG).init(); let accessor = crate::db_access::RtssDbAccessor::new(pool); + // 同步10个用户 + let mut users = vec![]; + for i in 0..10 { + let user = SyncUserInfo { + id: i + 1, + name: format!("user{}", i + 1), + password: "password".to_string(), + roles: serde_json::to_value(&vec![Role::User]).unwrap(), + email: None, + mobile: None, + created_at: Local::now(), + updated_at: None, + }; + users.push(user); + } + accessor.sync_user(users.as_slice()).await?; + // 创建草稿数据测试 let res = accessor .create_draft_data(CreateDraftData::new("test", DataType::Em, 10)) @@ -487,10 +512,10 @@ mod tests { assert!(get_by_id.is_shared); // save as new draft测试 - let new_draft = accessor.save_as_new_draft(res.id, "new draft", 11).await?; + let new_draft = accessor.save_as_new_draft(res.id, "new draft", 9).await?; println!("{:?}", new_draft); assert_eq!(new_draft.name, "new draft"); - assert_eq!(new_draft.user_id, 11); + assert_eq!(new_draft.user_id, 9); assert_eq!(new_draft.options, res.options); assert_eq!(new_draft.data.unwrap(), data); assert_eq!( diff --git a/crates/rtss_db/src/db_access/mod.rs b/crates/rtss_db/src/db_access/mod.rs index 85bc470..d9a5956 100644 --- a/crates/rtss_db/src/db_access/mod.rs +++ b/crates/rtss_db/src/db_access/mod.rs @@ -2,6 +2,8 @@ mod draft_data; pub use draft_data::*; mod release_data; pub use release_data::*; +mod user; +pub use user::*; #[derive(Clone)] pub struct RtssDbAccessor { diff --git a/crates/rtss_db/src/db_access/release_data.rs b/crates/rtss_db/src/db_access/release_data.rs index 778c65e..153bf8d 100644 --- a/crates/rtss_db/src/db_access/release_data.rs +++ b/crates/rtss_db/src/db_access/release_data.rs @@ -645,9 +645,10 @@ impl ReleaseDataAccessor for RtssDbAccessor { #[cfg(test)] mod tests { - use crate::{CreateDraftData, DraftDataAccessor, RtssDbAccessor}; + use crate::{CreateDraftData, DraftDataAccessor, RtssDbAccessor, SyncUserInfo, UserAccessor}; use super::*; + use chrono::Local; use rtss_dto::common::IscsStyle; use rtss_log::tracing::Level; use serde::{Deserialize, Serialize}; @@ -691,11 +692,33 @@ mod tests { pub style: IscsStyle, } + #[derive(Debug, Clone, Serialize, Deserialize)] + enum Role { + User, + Admin, + } + // You could also do `use foo_crate::MIGRATOR` and just refer to it as `MIGRATOR` here. #[sqlx::test(migrator = "crate::MIGRATOR")] async fn test_basic_use(pool: PgPool) -> Result<(), DbAccessError> { rtss_log::Logging::default().with_level(Level::DEBUG).init(); let accessor = RtssDbAccessor::new(pool); + // 同步10个用户 + let mut users = vec![]; + for i in 0..10 { + let user = SyncUserInfo { + id: i + 1, + name: format!("user{}", i + 1), + password: "password".to_string(), + roles: serde_json::to_value(&vec![Role::User]).unwrap(), + email: None, + mobile: None, + created_at: Local::now(), + updated_at: None, + }; + users.push(user); + } + accessor.sync_user(users.as_slice()).await?; // 创建草稿 let data = "test".as_bytes(); let draft = accessor diff --git a/crates/rtss_db/src/db_access/user.rs b/crates/rtss_db/src/db_access/user.rs new file mode 100644 index 0000000..aa64cc8 --- /dev/null +++ b/crates/rtss_db/src/db_access/user.rs @@ -0,0 +1,380 @@ +use serde_json::Value; +use sqlx::types::chrono::{DateTime, Local}; + +use crate::{ + common::{PageQuery, PageResult, TableColumn}, + model::{UserColumn, UserModel}, + DbAccessError, +}; + +use super::RtssDbAccessor; + +/// 草稿数据管理 +#[allow(async_fn_in_trait)] +pub trait UserAccessor { + /// 同步用户数据 + async fn sync_user(&self, users: &[SyncUserInfo]) -> Result<(), DbAccessError>; + /// 根据id列表查询用户name + async fn query_user_name(&self, ids: &[i32]) -> Result, DbAccessError>; + /// 分页查询用户数据 + async fn query_user_page( + &self, + page: PageQuery, + filter: UserPageFilter, + ) -> Result, DbAccessError>; +} + +#[derive(Debug, Clone)] +pub struct UserPageFilter { + pub id: Option, + pub name: Option, + pub email: Option, + pub mobile: Option, + pub roles: Option, +} + +impl UserPageFilter { + fn to_where_clause(&self) -> String { + let mut clauses = vec![]; + let id_column = UserColumn::Id.name(); + let name_column = UserColumn::Username.name(); + let email_column = UserColumn::Email.name(); + let mobile_column = UserColumn::Mobile.name(); + let roles_column = UserColumn::Roles.name(); + if let Some(id) = self.id { + clauses.push(format!( + "{id_column} = {id}", + id_column = id_column, + id = id + )); + } + if let Some(name) = &self.name { + clauses.push(format!( + "{name_column} LIKE '%{name}%'", + name_column = name_column, + name = name + )); + } + if let Some(email) = &self.email { + clauses.push(format!( + "{email_column} LIKE '%{email}%'", + email_column = email_column, + email = email + )); + } + if let Some(mobile) = &self.mobile { + clauses.push(format!( + "{mobile_column} LIKE '%{mobile}%'", + mobile_column = mobile_column, + mobile = mobile + )); + } + if let Some(roles) = &self.roles { + clauses.push(format!( + "{roles_column} @> '{roles}'", + roles_column = roles_column, + roles = roles + )); + } + if clauses.is_empty() { + return "".to_string(); + } + clauses.join(" AND ") + } +} + +#[derive(Debug, Clone)] +pub struct SyncUserInfo { + pub id: i32, + pub name: String, + pub password: String, + pub roles: Value, + pub email: Option, + pub mobile: Option, + pub created_at: DateTime, + pub updated_at: Option>, +} + +impl RtssDbAccessor { + /// 首次同步用户数据 + async fn sync_new_users(&self, users: &[SyncUserInfo]) -> Result<(), DbAccessError> { + let table = UserColumn::Table.name(); + let id = UserColumn::Id.name(); + let username = UserColumn::Username.name(); + let password = UserColumn::Password.name(); + let email = UserColumn::Email.name(); + let mobile = UserColumn::Mobile.name(); + let roles = UserColumn::Roles.name(); + let created_at = UserColumn::CreatedAt.name(); + let updated_at = UserColumn::UpdatedAt.name(); + let insert_columns = format!( + "{id}, {username}, {password}, {email}, {mobile}, {roles}, {created_at}, {updated_at}", + id = id, + username = username, + password = password, + email = email, + mobile = mobile, + roles = roles, + created_at = created_at, + updated_at = updated_at + ); + let insert_values = users + .iter() + .map(|user| { + format!( + "({id}, '{username}', '{password}', {email}, {mobile}, '{roles}', '{created_at}', '{updated_at}')", + id = user.id, + username = user.name, + password = user.password, + email = user.email.as_deref().map(|s| format!("'{s}'")).unwrap_or("NULL".to_string()), + mobile = user.mobile.as_deref().map(|s| format!("'{s}'")).unwrap_or("NULL".to_string()), + roles = user.roles, + created_at = user.created_at, + updated_at = user.updated_at.unwrap_or(user.created_at) + ) + }) + .collect::>() + .join(", "); + let insert_clause = format!( + "INSERT INTO {table} ({insert_columns}) VALUES {insert_values}", + table = table, + insert_columns = insert_columns, + insert_values = insert_values + ); + sqlx::query(&insert_clause).execute(&self.pool).await?; + + Ok(()) + } + + /// 检查并同步用户数据 + async fn check_and_sync_user(&self, users: &[SyncUserInfo]) -> Result<(), DbAccessError> { + // 查询用户表最大的用户id + let table = UserColumn::Table.name(); + let id = UserColumn::Id.name(); + let max_id_clause = format!("SELECT MAX({id}) FROM {table}"); + let max_id: Option = sqlx::query_scalar(&max_id_clause) + .fetch_one(&self.pool) + .await?; + if max_id.is_none() { + self.sync_new_users(users).await?; + return Ok(()); + } + // 遍历用户数据,如果id大于最大id则插入,否则根据更新时间查询是否需要更新,如果需要更新则更新 + // 获取所有id大于最大id的用户数据 + let max_id = max_id.unwrap(); + let mut new_users = vec![]; + for user in users.iter() { + if user.id > max_id { + new_users.push(user.clone()); + } + } + if !new_users.is_empty() { + self.sync_new_users(new_users.as_slice()).await?; + } + // 遍历用户数据,根据更新时间查询是否需要更新,如果需要更新则更新 + for user in users.iter() { + if user.id <= max_id { + let query_clause = format!( + "SELECT {updated_at} FROM {table} WHERE {id} = {user_id}", + updated_at = UserColumn::UpdatedAt.name(), + table = table, + id = id, + user_id = user.id + ); + let updated_at: Option> = sqlx::query_scalar(&query_clause) + .fetch_optional(&self.pool) + .await?; + if let Some(updated_at) = updated_at { + if user.updated_at.unwrap_or(user.created_at) > updated_at { + let username = UserColumn::Username.name(); + let password = UserColumn::Password.name(); + let email = UserColumn::Email.name(); + let mobile = UserColumn::Mobile.name(); + let roles = UserColumn::Roles.name(); + let created_at = UserColumn::CreatedAt.name(); + let updated_at = UserColumn::UpdatedAt.name(); + let update_clause = format!( + "UPDATE {table} SET {username} = '{new_username}', {password} = '{new_password}', {email} = {new_email}, {mobile} = {new_mobile}, {roles} = '{new_roles}', {created_at} = '{new_created_at}', {updated_at} = '{new_updated_at}' WHERE {id} = {user_id}", + table = table, + username = username, + new_username = user.name, + password = password, + new_password = user.password, + email = email, + new_email = user.email.as_deref().map(|s| format!("'{s}'")).unwrap_or("NULL".to_string()), + mobile = mobile, + new_mobile = user.mobile.as_deref().map(|s| format!("'{s}'")).unwrap_or("NULL".to_string()), + roles = roles, + new_roles = user.roles, + created_at = created_at, + new_created_at = user.created_at, + updated_at = updated_at, + new_updated_at = user.updated_at.unwrap_or(user.created_at), + id = id, + user_id = user.id + ); + sqlx::query(&update_clause).execute(&self.pool).await?; + } + } + } + } + Ok(()) + } +} + +impl UserAccessor for RtssDbAccessor { + async fn sync_user(&self, users: &[SyncUserInfo]) -> Result<(), DbAccessError> { + self.check_and_sync_user(users).await + } + + async fn query_user_name(&self, ids: &[i32]) -> Result, DbAccessError> { + let table = UserColumn::Table.name(); + let id = UserColumn::Id.name(); + let username = UserColumn::Username.name(); + let select_columns = format!("{id}, {username}"); + let query_clause = format!("SELECT {select_columns} FROM {table} WHERE {id} = ANY($1)",); + let rows = sqlx::query_as::<_, (i32, String)>(&query_clause) + .bind(ids) + .fetch_all(&self.pool) + .await?; + + Ok(rows) + } + + async fn query_user_page( + &self, + page: PageQuery, + filter: UserPageFilter, + ) -> Result, DbAccessError> { + let table = UserColumn::Table.name(); + let id_column = UserColumn::Id.name(); + let where_clause = filter.to_where_clause(); + let count_clause = format!("SELECT COUNT({id_column}) FROM {table} {where_clause}"); + let total: i64 = sqlx::query_scalar(&count_clause) + .fetch_one(&self.pool) + .await?; + if total == 0 { + return Ok(PageResult::new(total, vec![])); + } + let limit_clause = page.to_limit_clause(); + let query_clause = + format!("SELECT * FROM {table} {where_clause} ORDER BY {id_column} {limit_clause}",); + let rows = sqlx::query_as::<_, UserModel>(&query_clause) + .fetch_all(&self.pool) + .await?; + Ok(PageResult::new(total, rows)) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use rtss_log::tracing::Level; + use serde::{Deserialize, Serialize}; + use sqlx::PgPool; + + use super::*; + #[derive(Debug, Clone, Serialize, Deserialize)] + enum Role { + User, + Admin, + } + + #[test] + fn test_role_value_format() { + let roles = vec![Role::User, Role::Admin]; + let value = serde_json::to_value(&roles).unwrap(); + println!("{}", value); + println!("{}", serde_json::to_string(&value).unwrap()); + } + + #[sqlx::test(migrator = "crate::MIGRATOR")] + async fn test_sync_user(pool: PgPool) -> Result<(), DbAccessError> { + // 日志初始化 + rtss_log::Logging::default().with_level(Level::DEBUG).init(); + let accessor = RtssDbAccessor::new(pool); + let users = vec![ + SyncUserInfo { + id: 1, + name: "test1".to_string(), + password: "password".to_string(), + roles: serde_json::to_value(&vec![Role::User]).unwrap(), + email: None, + mobile: None, + created_at: Local::now(), + updated_at: None, + }, + SyncUserInfo { + id: 2, + name: "test2".to_string(), + password: "password".to_string(), + roles: serde_json::to_value(&vec![Role::Admin]).unwrap(), + email: None, + mobile: None, + created_at: Local::now(), + updated_at: None, + }, + ]; + accessor.sync_user(users.as_slice()).await?; + // 分页查询检查是否插入成功 + let page = PageQuery { + page: 1, + items_per_page: 10, + }; + let filter = UserPageFilter { + id: None, + name: None, + email: None, + mobile: None, + roles: None, + }; + let page_result = accessor + .query_user_page(page.clone(), filter.clone()) + .await?; + assert_eq!(page_result.total, 2); + assert_eq!(page_result.data.len(), 2); + println!("{:?}", page_result); + // 同时新增和更新用户 + let users = vec![ + SyncUserInfo { + id: 1, + name: "test1".to_string(), + password: "password".to_string(), + roles: serde_json::to_value(&vec![Role::User]).unwrap(), + email: Some("walker@163.com".to_string()), + mobile: None, + created_at: Local::now() - Duration::from_secs(60), + updated_at: Some(Local::now()), + }, + SyncUserInfo { + id: 2, + name: "test2".to_string(), + password: "password".to_string(), + roles: serde_json::to_value(&vec![Role::Admin, Role::User]).unwrap(), + email: None, + mobile: Some("123456789".to_string()), + created_at: Local::now() - Duration::from_secs(60), + updated_at: Some(Local::now()), + }, + SyncUserInfo { + id: 3, + name: "test3".to_string(), + password: "password".to_string(), + roles: serde_json::to_value(&vec![Role::User]).unwrap(), + email: None, + mobile: None, + created_at: Local::now(), + updated_at: None, + }, + ]; + accessor.sync_user(users.as_slice()).await?; + // 分页查询检查是否更新成功 + let page_result = accessor.query_user_page(page, filter.clone()).await?; + assert_eq!(page_result.total, 3); + assert_eq!(page_result.data.len(), 3); + println!("{:?}", page_result); + + Ok(()) + } +} diff --git a/crates/rtss_db/src/model.rs b/crates/rtss_db/src/model.rs index 3c633e0..6d23b13 100644 --- a/crates/rtss_db/src/model.rs +++ b/crates/rtss_db/src/model.rs @@ -3,6 +3,33 @@ use sqlx::types::chrono::{DateTime, Local}; use crate::common::TableColumn; +#[derive(Debug)] +pub enum UserColumn { + Table, + Id, + Username, + Password, + Email, + Mobile, + Roles, + CreatedAt, + UpdatedAt, +} + +#[derive(Debug, sqlx::FromRow)] +pub struct UserModel { + pub id: i32, + pub username: String, + pub password: String, + #[sqlx(default)] + pub email: Option, + #[sqlx(default)] + pub mobile: Option, + pub roles: Value, + pub created_at: DateTime, + pub updated_at: DateTime, +} + /// 数据库表 rtss.draft_data 列映射 #[derive(Debug)] pub enum DraftDataColumn { @@ -203,6 +230,22 @@ pub struct FeatureConfigModel { pub updated_at: DateTime, } +impl TableColumn for UserColumn { + fn name(&self) -> &str { + match self { + UserColumn::Table => "rtss.user", + UserColumn::Id => "id", + UserColumn::Username => "username", + UserColumn::Password => "password", + UserColumn::Email => "email", + UserColumn::Mobile => "mobile", + UserColumn::Roles => "roles", + UserColumn::CreatedAt => "created_at", + UserColumn::UpdatedAt => "updated_at", + } + } +} + impl TableColumn for DraftDataColumn { fn name(&self) -> &str { match self { diff --git a/migrations/20240830095636_init.up.sql b/migrations/20240830095636_init.up.sql index b8798ad..6b2eda1 100644 --- a/migrations/20240830095636_init.up.sql +++ b/migrations/20240830095636_init.up.sql @@ -1,6 +1,19 @@ -- 初始化数据库SCHEMA(所有轨道交通信号系统仿真的表、类型等都在rtss SCHEMA下) CREATE SCHEMA rtss; +-- 创建用户表 +CREATE TABLE + rtss.user ( + id SERIAL PRIMARY KEY, -- id 自增主键 + username VARCHAR(128) NOT NULL, -- 用户名 + password VARCHAR(128) NOT NULL, -- 密码 + email VARCHAR(128) NULL, -- 邮箱 + mobile VARCHAR(16) NULL, -- 手机号 + roles JSONB NOT NULL DEFAULT '[]', -- 角色列表 + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP -- 更新时间 + ); + -- 创建草稿数据表 CREATE TABLE rtss.draft_data ( @@ -14,6 +27,7 @@ CREATE TABLE is_shared BOOLEAN NOT NULL DEFAULT FALSE, -- 是否共享 created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间 + FOREIGN KEY (user_id) REFERENCES rtss.user (id) ON DELETE CASCADE, -- 用户外键 UNIQUE (name, user_id) -- 一个用户的草稿名称唯一 ); @@ -60,6 +74,7 @@ CREATE TABLE is_published BOOLEAN NOT NULL DEFAULT TRUE, -- 是否上架 created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间 + FOREIGN KEY (user_id) REFERENCES rtss.user (id) ON DELETE CASCADE, -- 用户外键 UNIQUE(data_type, name) -- 数据类型和名称唯一 ); @@ -95,6 +110,7 @@ CREATE TABLE description TEXT NOT NULL, -- 版本描述 user_id INT NOT NULL, -- 发布用户id created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 + FOREIGN KEY (user_id) REFERENCES rtss.user (id) ON DELETE CASCADE, -- 用户外键 FOREIGN KEY (release_data_id) REFERENCES rtss.release_data (id) ON DELETE CASCADE ); @@ -133,7 +149,9 @@ CREATE TABLE creator_id INT NOT NULL, -- 创建用户id updater_id INT NOT NULL, -- 更新用户id created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 - updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP -- 更新时间 + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间 + FOREIGN KEY (creator_id) REFERENCES rtss.user (id) ON DELETE CASCADE, -- 用户外键 + FOREIGN KEY (updater_id) REFERENCES rtss.user (id) ON DELETE CASCADE -- 用户外键 ); -- 注释仿真feature表 @@ -184,7 +202,9 @@ CREATE TABLE creator_id INT NOT NULL, -- 创建用户id updater_id INT NOT NULL, -- 更新用户id created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 - updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP -- 更新时间 + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间 + FOREIGN KEY (creator_id) REFERENCES rtss.user (id) ON DELETE CASCADE, -- 用户外键 + FOREIGN KEY (updater_id) REFERENCES rtss.user (id) ON DELETE CASCADE -- 用户外键 ); -- 注释仿真feature group表 @@ -227,6 +247,7 @@ CREATE TABLE config BYTEA NOT NULL, -- 配置 created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间 + FOREIGN KEY (user_id) REFERENCES rtss.user (id) ON DELETE CASCADE, -- 用户外键 FOREIGN KEY (feature_id) REFERENCES rtss.feature (id) ON DELETE CASCADE ); diff --git a/src/app_config.rs b/src/app_config.rs index 5ea7b7c..305d1f2 100644 --- a/src/app_config.rs +++ b/src/app_config.rs @@ -36,6 +36,7 @@ pub struct Sso { pub login_url: String, pub logout_url: String, pub user_info_url: String, + pub sync_user_url: String, } #[derive(Debug, Deserialize)] diff --git a/src/cmd.rs b/src/cmd.rs index bf2b142..17bc92f 100644 --- a/src/cmd.rs +++ b/src/cmd.rs @@ -38,6 +38,7 @@ impl CmdExecutor for ServerOpts { login_url: app_config.sso.login_url, logout_url: app_config.sso.logout_url, user_info_url: app_config.sso.user_info_url, + sync_user_url: app_config.sso.sync_user_url, }), ) .await