341 lines
10 KiB
Rust
341 lines
10 KiB
Rust
use std::io::{self, BufWriter, Write};
|
|
use std::{path::PathBuf, str::FromStr};
|
|
|
|
use crate::{result::*, ChannelMode};
|
|
|
|
#[derive(Debug, Default)]
|
|
pub struct Protocol {
|
|
#[allow(dead_code)]
|
|
version: Version,
|
|
}
|
|
|
|
impl From<Version> for Protocol {
|
|
fn from(version: Version) -> Self {
|
|
Self { version }
|
|
}
|
|
}
|
|
|
|
impl Protocol {
|
|
pub fn format_request(&self, req: Request) -> io::Result<Vec<u8>> {
|
|
let mut res = BufWriter::new(Vec::new());
|
|
|
|
match req {
|
|
Request::Quit => write!(res, "QUIT")?,
|
|
|
|
Request::Ping => write!(res, "PING")?,
|
|
|
|
Request::Start { mode, password } => write!(res, "START {} {}", mode, password)?,
|
|
|
|
#[rustfmt::skip]
|
|
Request::Count { collection, bucket, object } => match (bucket, object) {
|
|
(Some(b), Some(o)) => write!(res, "COUNT {} {} {}", collection, b, o)?,
|
|
(Some(b), None) => write!(res, "COUNT {} {}", collection, b)?,
|
|
(None, None) => write!(res, "COUNT {}", collection)?,
|
|
_ => panic!("Wrong protocol format"),
|
|
},
|
|
|
|
#[rustfmt::skip]
|
|
Request::Flush { collection, bucket, object } => match (bucket, object) {
|
|
(Some(b), Some(o)) => write!(res, "FLUSHO {} {} {}", collection, b, o)?,
|
|
(Some(b), None) => write!(res, "FLUSHB {} {}", collection, b)?,
|
|
(None, None) => write!(res, "FLUSHC {}", collection)?,
|
|
_ => panic!("Wrong protocol format"),
|
|
},
|
|
|
|
#[rustfmt::skip]
|
|
Request::Pop { collection, bucket, object, terms } => {
|
|
write!(res, "POP {} {} {} \"{}\"", collection, bucket, object, terms)?
|
|
},
|
|
#[rustfmt::skip]
|
|
Request::Push { collection, bucket, object, terms, lang } => {
|
|
let oneline_terms = remove_multiline(&terms);
|
|
write!(res, "PUSH {} {} {} \"{}\"", collection, bucket, object, oneline_terms)?;
|
|
if let Some(lang) = lang {
|
|
write!(res, " LANG({})", lang)?
|
|
}
|
|
}
|
|
|
|
#[rustfmt::skip]
|
|
Request::Query { collection, bucket, terms, offset, limit, lang } => {
|
|
write!(res, "QUERY {} {} \"{}\"", collection, bucket, terms)?;
|
|
if let Some(limit) = limit {
|
|
write!(res, " LIMIT({})", limit)?;
|
|
}
|
|
if let Some(offset) = offset {
|
|
write!(res, " OFFSET({})", offset)?;
|
|
}
|
|
if let Some(lang) = lang {
|
|
write!(res, " LANG({})", lang)?;
|
|
}
|
|
}
|
|
#[rustfmt::skip]
|
|
Request::Suggest { collection, bucket, word, limit } => {
|
|
write!(res, "SUGGEST {} {} \"{}\"", collection, bucket, word)?;
|
|
if let Some(limit) = limit {
|
|
write!(res, " LIMIT({})", limit)?;
|
|
}
|
|
}
|
|
|
|
#[rustfmt::skip]
|
|
Request::List { collection, bucket, limit, offset } => {
|
|
write!(res, "LIST {} {}", collection, bucket)?;
|
|
if let Some(limit) = limit {
|
|
write!(res, " LIMIT({})", limit)?;
|
|
}
|
|
if let Some(offset) = offset {
|
|
write!(res, " OFFSET({})", offset)?;
|
|
}
|
|
}
|
|
|
|
Request::Trigger(triger_req) => match triger_req {
|
|
TriggerRequest::Consolidate => write!(res, "TRIGGER consolidate")?,
|
|
TriggerRequest::Backup(path) => {
|
|
write!(res, "TRIGGER backup {}", path.to_str().unwrap())?
|
|
}
|
|
TriggerRequest::Restore(path) => {
|
|
write!(res, "TRIGGER restore {}", path.to_str().unwrap())?
|
|
}
|
|
},
|
|
}
|
|
|
|
write!(res, "\r\n")?;
|
|
res.flush()?;
|
|
|
|
Ok(res.into_inner()?)
|
|
}
|
|
|
|
pub fn parse_response(&self, line: &str) -> Result<Response> {
|
|
let mut segments = line.split_whitespace();
|
|
match segments.next() {
|
|
Some("STARTED") => match (segments.next(), segments.next(), segments.next()) {
|
|
(Some(_raw_mode), Some(raw_protocol), Some(raw_buffer_size)) => {
|
|
Ok(Response::Started(StartedPayload {
|
|
protocol_version: parse_server_config(raw_protocol)?,
|
|
max_buffer_size: parse_server_config(raw_buffer_size)?,
|
|
}))
|
|
}
|
|
_ => Err(Error::WrongResponse),
|
|
},
|
|
Some("PENDING") => {
|
|
let event_id = segments
|
|
.next()
|
|
.map(String::from)
|
|
.ok_or(Error::WrongResponse)?;
|
|
Ok(Response::Pending(event_id))
|
|
}
|
|
Some("RESULT") => match segments.next() {
|
|
Some(num) => num
|
|
.parse()
|
|
.map(Response::Result)
|
|
.map_err(|_| Error::WrongResponse),
|
|
_ => Err(Error::WrongResponse),
|
|
},
|
|
Some("EVENT") => {
|
|
let event_kind = match segments.next() {
|
|
Some("SUGGEST") => Ok(EventKind::Suggest),
|
|
Some("QUERY") => Ok(EventKind::Query),
|
|
Some("LIST") => Ok(EventKind::List),
|
|
_ => Err(Error::WrongResponse),
|
|
}?;
|
|
|
|
let event_id = segments
|
|
.next()
|
|
.map(String::from)
|
|
.ok_or(Error::WrongResponse)?;
|
|
|
|
let objects = segments.map(String::from).collect();
|
|
|
|
Ok(Response::Event(event_kind, event_id, objects))
|
|
}
|
|
Some("OK") => Ok(Response::Ok),
|
|
Some("ENDED") => Ok(Response::Ended),
|
|
Some("CONNECTED") => Ok(Response::Connected),
|
|
Some("ERR") => match segments.next() {
|
|
Some(message) => Err(Error::SonicServer(String::from(message))),
|
|
_ => Err(Error::WrongResponse),
|
|
},
|
|
_ => Err(Error::WrongResponse),
|
|
}
|
|
}
|
|
}
|
|
|
|
//===========================================================================//
|
|
// Primitives //
|
|
//===========================================================================//
|
|
|
|
/// Sonic channel protocol versions known to this client.
///
/// `#[repr(u8)]` with explicit discriminants lets a variant be cast to its protocol number,
/// e.g. `Version::V1 as u8 == 1`.
// Fieldless, so `Clone`/`Copy`/`Hash` are free. `Copy` in particular allows reading or
// casting the version through a shared reference without moving it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Version {
    /// Protocol version 1.
    V1 = 1,
}
|
|
|
|
impl Default for Version {
|
|
fn default() -> Self {
|
|
Self::V1
|
|
}
|
|
}
|
|
|
|
impl TryFrom<u8> for Version {
|
|
type Error = ();
|
|
|
|
fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
|
|
match value {
|
|
1 => Ok(Self::V1),
|
|
_ => Err(()),
|
|
}
|
|
}
|
|
}
|
|
|
|
//===========================================================================//
|
|
// Response //
|
|
//===========================================================================//
|
|
|
|
/// Identifier the server attaches to an operation, as carried by `PENDING` and `EVENT`
/// replies.
pub type EventId = String;
|
|
|
|
#[derive(Debug)]
|
|
pub enum Response {
|
|
Ok,
|
|
Ended,
|
|
Connected,
|
|
Pending(EventId),
|
|
Pong,
|
|
Started(StartedPayload),
|
|
Result(usize),
|
|
Event(EventKind, EventId, Vec<String>),
|
|
}
|
|
|
|
/// Server settings announced in a `STARTED` reply.
#[derive(Debug)]
pub struct StartedPayload {
    /// Protocol version: the parenthesised value of the second item after `STARTED`.
    pub protocol_version: u8,
    /// Maximum buffer size: the parenthesised value of the third item after `STARTED`.
    pub max_buffer_size: usize,
}
|
|
|
|
/// Operation named in an `EVENT` reply (`SUGGEST`, `QUERY` or `LIST`).
#[derive(Debug)]
pub enum EventKind {
    /// `EVENT SUGGEST ...`
    Suggest,
    /// `EVENT QUERY ...`
    Query,
    /// `EVENT LIST ...`
    List,
}
|
|
|
|
//===========================================================================//
|
|
// Request //
|
|
//===========================================================================//
|
|
|
|
#[derive(Debug)]
|
|
pub enum Request {
|
|
Start {
|
|
mode: ChannelMode,
|
|
password: String,
|
|
},
|
|
Quit,
|
|
Ping,
|
|
Trigger(TriggerRequest),
|
|
Suggest {
|
|
collection: String,
|
|
bucket: String,
|
|
word: String,
|
|
limit: Option<usize>,
|
|
},
|
|
List {
|
|
collection: String,
|
|
bucket: String,
|
|
limit: Option<usize>,
|
|
offset: Option<usize>,
|
|
},
|
|
Query {
|
|
collection: String,
|
|
bucket: String,
|
|
terms: String,
|
|
offset: Option<usize>,
|
|
limit: Option<usize>,
|
|
lang: Option<&'static str>,
|
|
},
|
|
Push {
|
|
collection: String,
|
|
bucket: String,
|
|
object: String,
|
|
terms: String,
|
|
lang: Option<&'static str>,
|
|
},
|
|
Pop {
|
|
collection: String,
|
|
bucket: String,
|
|
object: String,
|
|
terms: String,
|
|
},
|
|
Flush {
|
|
collection: String,
|
|
bucket: Option<String>,
|
|
object: Option<String>,
|
|
},
|
|
Count {
|
|
collection: String,
|
|
bucket: Option<String>,
|
|
object: Option<String>,
|
|
},
|
|
}
|
|
|
|
/// Server-side maintenance action requested through `TRIGGER`.
#[derive(Debug)]
pub enum TriggerRequest {
    /// `TRIGGER consolidate`.
    Consolidate,
    /// `TRIGGER backup <path>`.
    Backup(PathBuf),
    /// `TRIGGER restore <path>`.
    Restore(PathBuf),
}
|
|
|
|
//===========================================================================//
|
|
// Utils //
|
|
//===========================================================================//
|
|
|
|
fn parse_server_config<T: FromStr>(raw: &str) -> Result<T> {
|
|
raw.split_terminator(&['(', ')'])
|
|
.nth(1)
|
|
.ok_or(Error::WrongResponse)?
|
|
.parse()
|
|
.map_err(|_| Error::WrongResponse)
|
|
}
|
|
|
|
/// Collapses `text` onto a single line.
///
/// Lines (split by `str::lines`, so `\n` / `\r\n` terminators are dropped) are concatenated
/// in order. A single space is inserted before each non-empty line, unless the output so far
/// is empty or already ends with a space. Empty lines contribute nothing, and the content of
/// each line is otherwise kept verbatim.
fn remove_multiline(text: &str) -> String {
    let mut single_line = String::new();
    for line in text.lines() {
        let needs_separator =
            !line.is_empty() && !single_line.is_empty() && !single_line.ends_with(' ');
        if needs_separator {
            single_line.push(' ');
        }
        single_line.push_str(line);
    }
    single_line
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The value inside `protocol(...)` parses as a `u8`.
    #[test]
    fn should_parse_protocol() {
        let parsed = parse_server_config::<u8>("protocol(1)");
        assert!(matches!(parsed, Ok(1)), "expected protocol version 1");
    }

    /// The value inside `buffer_size(...)` parses as a `usize`.
    #[test]
    fn should_parse_buffer_size() {
        let parsed = parse_server_config::<usize>("buffer_size(20000)");
        assert!(matches!(parsed, Ok(20000)), "expected buffer size 20000");
    }

    /// Leading/trailing blank lines vanish and inner line breaks become single spaces.
    #[test]
    fn should_make_single_line() {
        let text = "\nHello\nWorld\n";
        assert_eq!(remove_multiline(text), "Hello World");
    }
}
|