Compare commits

...

3 Commits

Author SHA1 Message Date
3ed0f5af1e
更新类型定义和回调函数签名***
***更新类型定义和回调函数签名
2024-02-24 23:07:31 +08:00
c366f6a735
就先这样吧( 2024-02-24 23:01:49 +08:00
db3905eec3
更新SendMessage类以支持链式调用 2024-02-24 19:26:21 +08:00
3 changed files with 166 additions and 38 deletions

View File

@ -1,6 +1,16 @@
# Python 兼容版本 3.8+
from typing import Optional
from typing import Optional, Callable
"""
pub type RoomId = i64;
pub type UserId = i64;
pub type MessageId = String;
"""
RoomId = int
UserId = int
MessageId = str
class IcaStatus:
@ -11,7 +21,7 @@ class IcaStatus:
def online(self) -> bool:
...
@property
def self_id(self) -> Optional[bool]:
def self_id(self) -> Optional[UserId]:
...
@property
def nick_name(self) -> Optional[str]:
@ -38,7 +48,18 @@ class ReplyMessage:
class SendMessage:
...
@property
def content(self) -> str:
...
@content.setter
def content(self, value: str) -> None:
...
def with_content(self, content: str) -> "SendMessage":
"""
为了链式调用, 返回自身
"""
self.content = content
return self
class NewMessage:
@ -47,10 +68,13 @@ class NewMessage:
def __str__(self) -> str:
...
@property
def id(self) -> MessageId:
...
@property
def content(self) -> str:
...
@property
def sender_id(self) -> int:
def sender_id(self) -> UserId:
...
@property
def is_from_self(self) -> bool:
@ -64,19 +88,26 @@ class IcaClient:
@staticmethod
async def send_message_a(client: "IcaClient", message: SendMessage) -> bool:
"""
仅作占位
仅作占位, 不能使用
(因为目前来说, rust调用 Python端没法启动一个异步运行时
所以只能 tokio::task::block_in_place 转换成同步调用)
"""
def send_message(self, message: SendMessage) -> bool:
...
def debug(self, message: str) -> None:
...
"""向日志中输出调试信息"""
def info(self, message: str) -> None:
...
"""向日志中输出信息"""
def warn(self, message: str) -> None:
...
"""向日志中输出警告信息"""
def on_message(msg: NewMessage, client: IcaClient) -> None:
...
on_load = Callable[[IcaClient], None]
# def on_load(client: IcaClient) -> None:
# ...
on_message = Callable[[NewMessage, IcaClient], None]
# def on_message(msg: NewMessage, client: IcaClient) -> None:
# ...
on_delete_message = Callable[[int, IcaClient], None]

View File

@ -5,6 +5,7 @@ use tracing::{debug, info, warn};
use crate::client::send_message;
use crate::data_struct::messages::{NewMessage, ReplyMessage, SendMessage};
use crate::data_struct::MessageId;
use crate::ClientStatus;
#[pyclass]
@ -126,7 +127,10 @@ impl NewMessagePy {
pub fn __str__(&self) -> String {
format!("{:?}", self.msg)
}
#[getter]
pub fn get_id(&self) -> MessageId {
self.msg.msg_id.clone()
}
#[getter]
pub fn get_content(&self) -> String {
self.msg.content.clone()
@ -182,6 +186,20 @@ impl SendMessagePy {
pub fn __str__(&self) -> String {
format!("{:?}", self.msg)
}
/// 设置消息内容
/// 用于链式调用
pub fn with_content(&mut self, content: String) -> Self {
self.msg.content = content;
self.clone()
}
#[getter]
pub fn get_content(&self) -> String {
self.msg.content.clone()
}
#[setter]
pub fn set_content(&mut self, content: String) {
self.msg.content = content;
}
}
impl SendMessagePy {

View File

@ -3,6 +3,7 @@ pub mod class;
use std::time::SystemTime;
use std::{collections::HashMap, path::PathBuf};
use futures_util::future::join_all;
use pyo3::prelude::*;
use rust_socketio::asynchronous::Client;
use tracing::{debug, info, warn};
@ -198,41 +199,119 @@ pub fn init_py(config: &IcaConfig) {
pub async fn new_message_py(message: &NewMessage, client: &Client) {
// 验证插件是否改变
verify_plugins();
let cwd = std::env::current_dir().unwrap();
let plugins = PyStatus::get_files();
for (path, (_, py_module)) in plugins.iter() {
// 切换工作目录到运行的插件的位置
let mut goto = cwd.clone();
goto.push(path.parent().unwrap());
if let Err(e) = std::env::set_current_dir(&goto) {
warn!("移动工作目录到 {:?} 失败 {:?} cwd: {:?}", goto, e, cwd);
}
// let tasks: Vec<_> = plugins.iter().map(|(path, (_, py_module))| {
// let msg = class::NewMessagePy::new(message);
// let client = class::IcaClientPy::new(client);
// let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();
// let task = tokio::spawn(async move {
// tokio::select! {
// _ = tokio::spawn(async move {Python::with_gil(|py| {
// let args = (msg, client);
// let async_py_func = py_module.getattr(py, "on_message");
// match async_py_func {
// Ok(async_py_func) => match async_py_func.as_ref(py).call1(args) {
// Err(e) => {
// warn!("get a PyErr when call on_message from {:?}: {:?}", path, e);
// }
// _ => (),
// },
// Err(e) => {
// warn!("failed to get on_message function: {:?}", e);
// }
// }
// })}) => (),
// _ = cancel_rx => (),
// }
// });
// (task, cancel_tx)
// }).collect();
Python::with_gil(|py| {
let msg = class::NewMessagePy::new(message);
let client = class::IcaClientPy::new(client);
let args = (msg, client);
let async_py_func = py_module.getattr(py, "on_message");
match async_py_func {
Ok(async_py_func) => {
match async_py_func.as_ref(py).call1(args) {
// let timeout = tokio::time::sleep(std::time::Duration::from_secs(5));
// tokio::select! {
// _ = join_all(tasks.map(|(task, _)| task)) => (),
// _ = timeout => {
// warn!("timeout when join all tasks");
// for (_, cancel_tx) in &tasks {
// let _ = cancel_tx.send(());
// }
// }
// }
// for (path, (_, py_module)) in plugins.iter() {
// let msg = class::NewMessagePy::new(message);
// let client = class::IcaClientPy::new(client);
// let task = tokio::spawn(async move {
// Python::with_gil(|py| {
// let args = (msg, client);
// let async_py_func = py_module.getattr(py, "on_message");
// match async_py_func {
// Ok(async_py_func) => match async_py_func.as_ref(py).call1(args) {
// Err(e) => {
// warn!("get a PyErr when call on_message from {:?}: {:?}", path, e);
// }
// _ => (),
// },
// Err(e) => {
// warn!("failed to get on_message function: {:?}", e);
// }
// }
// })
// });
// tokio::select! {
// _ = task => (),
// _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
// warn!("timeout when join all tasks");
// // task.abort();
// }
// }
// }
let mut tasks = Vec::with_capacity(plugins.len());
for (path, (_, py_module)) in plugins.iter() {
let msg = class::NewMessagePy::new(message);
let client = class::IcaClientPy::new(client);
let task = tokio::spawn(async move {
Python::with_gil(|py| {
let args = (msg, client);
let async_py_func = py_module.getattr(py, "on_message");
match async_py_func {
Ok(async_py_func) => match async_py_func.as_ref(py).call1(args) {
Err(e) => {
warn!("get a PyErr when call on_message from {:?}: {:?}", path, e);
},
_ => ()
}
_ => (),
},
Err(e) => {
warn!("failed to get on_message function: {:?}", e);
}
}
Err(e) => {
warn!("failed to get on_message function: {:?}", e);
}
}
})
});
tasks.push(task);
}
// 最后切换回来
if let Err(e) = std::env::set_current_dir(&cwd) {
warn!("设置工作目录{:?} 失败:{:?}", cwd, e);
// 等待所有的插件执行完毕
// 超时时间为 0.1 秒
// ~~ 超时则取消所有的任务 ~~
// 超时就超时了……, 就让他跑着了……
// 主要是, 这玩意是同步的 还没法取消
let wait_time = std::time::Duration::from_millis(100);
let awaits = join_all(tasks);
let timeout = tokio::time::sleep(wait_time.clone());
let await_task = tokio::time::timeout(wait_time.clone(), awaits);
tokio::select! {
_ = await_task => (),
_ = timeout => {
warn!("timeout when join all tasks");
// for task in tasks {
// task.abort();
// }
}
}
// match tokio::time::timeout(wait_time.clone(), awaits).await {
// Ok(_) => (),
// Err(e) => {
// warn!("timeout when join all tasks: {:?}", e);
// }
// }
}