diff --git a/Cargo.lock b/Cargo.lock index 105a870..4cffc9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,10 +160,13 @@ dependencies = [ "chrono", "fast_chemail", "fnv", + "futures-channel", + "futures-timer", "futures-util", "handlebars", "http", "indexmap", + "lru", "mime", "multer", "num-traits", @@ -1158,6 +1161,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -1544,6 +1553,15 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "lru" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "matchers" version = "0.1.0" @@ -2978,15 +2996,13 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.5.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +checksum = "41515cc9e193536d93fd0dbbea0c73819c08eca76e0b30909a325c3ec90985bb" dependencies = [ "bitflags 2.6.0", "bytes", "http", - "http-body", - "http-body-util", "pin-project-lite", "tower-layer", "tower-service", diff --git a/crates/rtss_api/Cargo.toml b/crates/rtss_api/Cargo.toml index 681e483..4ffd1ef 100644 --- a/crates/rtss_api/Cargo.toml +++ b/crates/rtss_api/Cargo.toml @@ -12,8 +12,8 @@ chrono = { version = "0.4.38", features = ["serde"] } axum = "0.7.5" axum-extra = { version = "0.9.3", features = ["typed-header"] } # jsonwebtoken = "9.3.0" -tower-http = { version = "0.5.0", features = ["cors"] } -async-graphql = { version = "7.0.7", features = ["chrono"] } +tower-http = { version = "0.6.0", features = ["cors"] } +async-graphql = { version = "7.0.7", features = ["chrono", "dataloader"] } async-graphql-axum = "7.0.6" base64 = "0.22.1" sysinfo = "0.31.3" diff --git a/crates/rtss_api/src/apis/draft_data.rs b/crates/rtss_api/src/apis/draft_data.rs index 6835737..11092e7 100644 --- a/crates/rtss_api/src/apis/draft_data.rs +++ b/crates/rtss_api/src/apis/draft_data.rs @@ -112,7 +112,9 @@ impl DraftDataMutation { data: String, // base64编码的数据 ) -> async_graphql::Result { let db_accessor = ctx.data::()?; - let bytes = BASE64_STANDARD.decode(data).expect("base64 decode error"); + let bytes = BASE64_STANDARD + .decode(data) + .map_err(|e| async_graphql::Error::new(format!("base64 decode error: {}", e)))?; let draft_data = db_accessor.update_draft_data_data(id, &bytes).await?; Ok(draft_data.into()) } @@ -303,3 +305,17 @@ impl From for DraftIscsDataDto { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_base64() { + let data = b"hello world"; + let encoded = BASE64_STANDARD.encode(data); + let decoded = BASE64_STANDARD.decode(&encoded).unwrap(); + assert_eq!(data, decoded.as_slice()); + println!("encoded: {}, decoded: {:?}", encoded, decoded); + } +} diff --git a/crates/rtss_api/src/apis/mod.rs b/crates/rtss_api/src/apis/mod.rs index ec86a5b..23b4e19 100644 --- a/crates/rtss_api/src/apis/mod.rs +++ b/crates/rtss_api/src/apis/mod.rs @@ -1,8 +1,12 @@ -use async_graphql::MergedObject; +use async_graphql::dataloader::DataLoader; +use async_graphql::{EmptySubscription, MergedObject, Schema}; use async_graphql::{Enum, InputObject, OutputType, SimpleObject}; use draft_data::{DraftDataMutation, DraftDataQuery}; use release_data::{ReleaseDataMutation, ReleaseDataQuery}; +use crate::simulation_definition::MutexSimulationManager; +use crate::ServerConfig; + mod common; mod draft_data; mod release_data; @@ -70,3 +74,25 @@ impl> From> for PageDto ) } } + +pub struct RtssDbLoader { + pub(crate) db_accessor: rtss_db::RtssDbAccessor, +} + +impl RtssDbLoader { + pub fn new(db_accessor: rtss_db::RtssDbAccessor) -> Self { + Self { db_accessor } + } +} + +pub type RtssAppSchema = Schema; + +pub async fn new_schema(config: &ServerConfig) -> RtssAppSchema { + let dba = rtss_db::get_db_accessor(&config.database_url).await; + let loader = RtssDbLoader::new(dba.clone()); + Schema::build(Query::default(), Mutation::default(), EmptySubscription) + .data(dba) + .data(DataLoader::new(loader, tokio::spawn)) + .data(MutexSimulationManager::default()) + .finish() +} diff --git a/crates/rtss_api/src/apis/release_data.rs b/crates/rtss_api/src/apis/release_data.rs index 01e30d8..05089eb 100644 --- a/crates/rtss_api/src/apis/release_data.rs +++ b/crates/rtss_api/src/apis/release_data.rs @@ -1,7 +1,12 @@ -use async_graphql::{Context, InputObject, Object, SimpleObject}; +use std::collections::HashMap; +use std::sync::Arc; + +use async_graphql::dataloader::*; +use async_graphql::{ComplexObject, Context, InputObject, Object, SimpleObject}; use base64::prelude::*; use chrono::NaiveDateTime; use rtss_db::model::*; +use rtss_db::prelude::*; use rtss_db::{model::ReleaseDataModel, ReleaseDataAccessor, RtssDbAccessor}; use rtss_dto::common::DataType; use serde_json::Value; @@ -9,7 +14,7 @@ use serde_json::Value; use crate::apis::draft_data::DraftDataDto; use super::common::{DataOptions, IscsDataOptions}; -use super::{PageDto, PageQueryDto}; +use super::{PageDto, PageQueryDto, RtssDbLoader}; #[derive(Default)] pub struct ReleaseDataQuery; @@ -59,6 +64,17 @@ impl ReleaseDataQuery { Ok(model.into()) } + /// 是否已经存在相同name的发布数据 + async fn is_release_data_name_exists( + &self, + ctx: &Context<'_>, + name: String, + ) -> async_graphql::Result { + let db_accessor = ctx.data::()?; + let result = db_accessor.is_release_data_name_exist(&name).await?; + Ok(result) + } + /// 查询发布数据的版本 async fn release_data_version_paging( &self, @@ -162,12 +178,12 @@ impl ReleaseDataMutation { async fn create_draft_data_from_release_data_version( &self, ctx: &Context<'_>, - release_data_id: i32, version_id: i32, + user_id: i32, ) -> async_graphql::Result { let db_accessor = ctx.data::()?; let result = db_accessor - .create_draft_from_release_version(release_data_id, version_id) + .create_draft_from_release_version(version_id, user_id) .await?; Ok(result.into()) } @@ -215,6 +231,7 @@ impl From for rtss_db::ReleaseDataQuery { } #[derive(Debug, SimpleObject)] +#[graphql(complex)] pub struct ReleaseDataDto { pub id: i32, pub name: String, @@ -227,6 +244,53 @@ pub struct ReleaseDataDto { pub updated_at: NaiveDateTime, } +#[ComplexObject] +impl ReleaseDataDto { + async fn description(&self, ctx: &Context<'_>) -> async_graphql::Result> { + if let Some(version_id) = self.used_version_id { + let loader = ctx.data_unchecked::>(); + let description = loader + .load_one(ReleaseDataVersionId::new(version_id)) + .await?; + Ok(description) + } else { + Ok(None) + } + } +} + +#[derive(Clone, Copy, Hash, PartialEq, Eq)] +pub struct ReleaseDataVersionId { + pub id: i32, +} + +impl ReleaseDataVersionId { + pub fn new(id: i32) -> Self { + Self { id } + } +} + +impl Loader for RtssDbLoader { + type Value = String; + type Error = Arc; + + async fn load( + &self, + keys: &[ReleaseDataVersionId], + ) -> Result, Self::Error> { + let ids: Vec = keys.iter().map(|k| k.id).collect(); + let rows = self + .db_accessor + .query_release_data_version_descriptions(ids.as_slice()) + .await?; + let map: HashMap = rows + .into_iter() + .map(|r| (ReleaseDataVersionId { id: r.0 }, r.1)) + .collect(); + Ok(map) + } +} + #[derive(Debug, SimpleObject)] pub struct ReleaseIscsDataWithoutVersionDto { pub release_data: ReleaseDataDto, diff --git a/crates/rtss_api/src/server.rs b/crates/rtss_api/src/server.rs index 65c220d..9199104 100644 --- a/crates/rtss_api/src/server.rs +++ b/crates/rtss_api/src/server.rs @@ -1,5 +1,4 @@ use async_graphql::*; -use async_graphql::{EmptySubscription, Schema}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; use axum::extract::State; use axum::http::HeaderMap; @@ -14,8 +13,7 @@ use rtss_log::tracing::{debug, info}; use tokio::net::TcpListener; use tower_http::cors::CorsLayer; -use crate::apis::{Mutation, Query}; -use crate::simulation_definition::MutexSimulationManager; +use crate::apis::RtssAppSchema; pub struct ServerConfig { pub database_url: String, @@ -36,7 +34,7 @@ impl ServerConfig { } pub async fn serve(config: ServerConfig) -> anyhow::Result<()> { - let schema = new_schema(&config).await; + let schema = crate::apis::new_schema(&config).await; let app = Router::new() .route("/", get(graphiql).post(graphql_handler)) @@ -59,7 +57,7 @@ pub async fn serve(config: ServerConfig) -> anyhow::Result<()> { } async fn graphql_handler( - State(schema): State, + State(schema): State, _headers: HeaderMap, req: GraphQLRequest, ) -> GraphQLResponse { @@ -80,13 +78,3 @@ async fn graphql_handler( async fn graphiql() -> impl IntoResponse { Html(playground_source(GraphQLPlaygroundConfig::new("/"))) } - -pub type SimulationSchema = Schema; - -pub async fn new_schema(config: &ServerConfig) -> SimulationSchema { - let dba = rtss_db::get_db_accessor(&config.database_url).await; - Schema::build(Query::default(), Mutation::default(), EmptySubscription) - .data(dba) - .data(MutexSimulationManager::default()) - .finish() -} diff --git a/crates/rtss_db/src/db_access/mod.rs b/crates/rtss_db/src/db_access/mod.rs index dd8a81d..85bc470 100644 --- a/crates/rtss_db/src/db_access/mod.rs +++ b/crates/rtss_db/src/db_access/mod.rs @@ -3,6 +3,7 @@ pub use draft_data::*; mod release_data; pub use release_data::*; +#[derive(Clone)] pub struct RtssDbAccessor { pool: sqlx::PgPool, } diff --git a/crates/rtss_db/src/db_access/release_data.rs b/crates/rtss_db/src/db_access/release_data.rs index 68eb809..d5f6e27 100644 --- a/crates/rtss_db/src/db_access/release_data.rs +++ b/crates/rtss_db/src/db_access/release_data.rs @@ -51,6 +51,11 @@ pub trait ReleaseDataAccessor { &self, version_id: i32, ) -> Result; + /// 根据id列表查询发布版本数据description + async fn query_release_data_version_descriptions( + &self, + version_ids: &[i32], + ) -> Result, DbAccessError>; /// 查询发布数据详情 async fn query_release_data_with_used_version( &self, @@ -445,6 +450,25 @@ impl ReleaseDataAccessor for RtssDbAccessor { Ok(rdv) } + async fn query_release_data_version_descriptions( + &self, + version_ids: &[i32], + ) -> Result, DbAccessError> { + // 查询发布数据版本 + let rdv_table = ReleaseDataVersionColumn::Table.name(); + let rdv_id = ReleaseDataVersionColumn::Id.name(); + let rdv_description = ReleaseDataVersionColumn::Description.name(); + let select_columns = format!("{rdv_id}, {rdv_description}"); + let rdv_query_clause = + format!("SELECT {select_columns} FROM {rdv_table} WHERE {rdv_id} = ANY($1)",); + let rdv = sqlx::query_as::<_, (i32, String)>(&rdv_query_clause) + .bind(version_ids) + .fetch_all(&self.pool) + .await?; + + Ok(rdv) + } + async fn query_release_data_with_used_version( &self, release_id: i32, @@ -780,6 +804,17 @@ mod tests { assert_eq!(page_result.total, 8); println!("分页查询发布数据测试成功"); + // 测试根据数据版本id查询descriptions + let version_ids: Vec = page_result + .data + .into_iter() + .map(|d| d.used_version_id.unwrap()) + .collect(); + let description_map = accessor + .query_release_data_version_descriptions(version_ids.as_slice()) + .await?; + println!("{:?}", description_map); + Ok(()) } } diff --git a/crates/rtss_db/src/error.rs b/crates/rtss_db/src/error.rs index 868e94e..4610ac2 100644 --- a/crates/rtss_db/src/error.rs +++ b/crates/rtss_db/src/error.rs @@ -4,7 +4,7 @@ use thiserror::Error; pub enum DbAccessError { #[error("未知的数据库访问错误")] Unknown, - #[error("sqlx 错误: {0}")] + #[error("数据访问错误: {0}")] SqlxError(#[from] sqlx::Error), #[error("数据已存在")] RowExist,