diff --git a/ica-rs/src/py/mod.rs b/ica-rs/src/py/mod.rs index 0441a38..f955f1a 100644 --- a/ica-rs/src/py/mod.rs +++ b/ica-rs/src/py/mod.rs @@ -3,6 +3,7 @@ pub mod class; use std::time::SystemTime; use std::{collections::HashMap, path::PathBuf}; +use futures_util::future::join_all; use pyo3::prelude::*; use rust_socketio::asynchronous::Client; use tracing::{debug, info, warn}; @@ -198,41 +199,119 @@ pub fn init_py(config: &IcaConfig) { pub async fn new_message_py(message: &NewMessage, client: &Client) { // 验证插件是否改变 verify_plugins(); - let cwd = std::env::current_dir().unwrap(); let plugins = PyStatus::get_files(); - for (path, (_, py_module)) in plugins.iter() { - // 切换工作目录到运行的插件的位置 - let mut goto = cwd.clone(); - goto.push(path.parent().unwrap()); - - if let Err(e) = std::env::set_current_dir(&goto) { - warn!("移动工作目录到 {:?} 失败 {:?} cwd: {:?}", goto, e, cwd); - } + // let tasks: Vec<_> = plugins.iter().map(|(path, (_, py_module))| { + // let msg = class::NewMessagePy::new(message); + // let client = class::IcaClientPy::new(client); + // let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel(); + // let task = tokio::spawn(async move { + // tokio::select! { + // _ = tokio::spawn(async move {Python::with_gil(|py| { + // let args = (msg, client); + // let async_py_func = py_module.getattr(py, "on_message"); + // match async_py_func { + // Ok(async_py_func) => match async_py_func.as_ref(py).call1(args) { + // Err(e) => { + // warn!("get a PyErr when call on_message from {:?}: {:?}", path, e); + // } + // _ => (), + // }, + // Err(e) => { + // warn!("failed to get on_message function: {:?}", e); + // } + // } + // })}) => (), + // _ = cancel_rx => (), + // } + // }); + // (task, cancel_tx) + // }).collect(); - Python::with_gil(|py| { - let msg = class::NewMessagePy::new(message); - let client = class::IcaClientPy::new(client); - let args = (msg, client); - let async_py_func = py_module.getattr(py, "on_message"); - match async_py_func { - Ok(async_py_func) => { - match async_py_func.as_ref(py).call1(args) { + // let timeout = tokio::time::sleep(std::time::Duration::from_secs(5)); + // tokio::select! { + // _ = join_all(tasks.map(|(task, _)| task)) => (), + // _ = timeout => { + // warn!("timeout when join all tasks"); + // for (_, cancel_tx) in &tasks { + // let _ = cancel_tx.send(()); + // } + // } + // } + // for (path, (_, py_module)) in plugins.iter() { + // let msg = class::NewMessagePy::new(message); + // let client = class::IcaClientPy::new(client); + // let task = tokio::spawn(async move { + // Python::with_gil(|py| { + // let args = (msg, client); + // let async_py_func = py_module.getattr(py, "on_message"); + // match async_py_func { + // Ok(async_py_func) => match async_py_func.as_ref(py).call1(args) { + // Err(e) => { + // warn!("get a PyErr when call on_message from {:?}: {:?}", path, e); + // } + // _ => (), + // }, + // Err(e) => { + // warn!("failed to get on_message function: {:?}", e); + // } + // } + // }) + // }); + // tokio::select! { + // _ = task => (), + // _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { + // warn!("timeout when join all tasks"); + // // task.abort(); + // } + + // } + // } + let mut tasks = Vec::with_capacity(plugins.len()); + for (path, (_, py_module)) in plugins.iter() { + let msg = class::NewMessagePy::new(message); + let client = class::IcaClientPy::new(client); + let task = tokio::spawn(async move { + Python::with_gil(|py| { + let args = (msg, client); + let async_py_func = py_module.getattr(py, "on_message"); + match async_py_func { + Ok(async_py_func) => match async_py_func.as_ref(py).call1(args) { Err(e) => { warn!("get a PyErr when call on_message from {:?}: {:?}", path, e); - }, - _ => () + } + _ => (), + }, + Err(e) => { + warn!("failed to get on_message function: {:?}", e); } } - Err(e) => { - warn!("failed to get on_message function: {:?}", e); - } - } + }) }); + tasks.push(task); } - - // 最后切换回来 - if let Err(e) = std::env::set_current_dir(&cwd) { - warn!("设置工作目录{:?} 失败:{:?}", cwd, e); + // 等待所有的插件执行完毕 + // 超时时间为 0.1 秒 + // ~~ 超时则取消所有的任务 ~~ + // 超时就超时了……, 就让他跑着了…… + // 主要是, 这玩意是同步的 还没法取消 + let wait_time = std::time::Duration::from_millis(100); + let awaits = join_all(tasks); + let timeout = tokio::time::sleep(wait_time.clone()); + let await_task = tokio::time::timeout(wait_time.clone(), awaits); + tokio::select! { + _ = await_task => (), + _ = timeout => { + warn!("timeout when join all tasks"); + // for task in tasks { + // task.abort(); + // } + } } + // match tokio::time::timeout(wait_time.clone(), awaits).await { + // Ok(_) => (), + // Err(e) => { + // warn!("timeout when join all tasks: {:?}", e); + // } + // } }