Files
letro-authentication-service/crates/policy/src/lib.rs
2022-06-03 13:37:20 +02:00

203 lines
5.7 KiB
Rust

// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::Cursor;
use anyhow::bail;
use oauth2_types::registration::ClientMetadata;
use opa_wasm::Runtime;
use serde::Deserialize;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};
use wasmtime::{Config, Engine, Module, Store};
/// Bytes of the default policy, embedded at compile time from
/// `../policies/policy.wasm` (path relative to this source file).
const DEFAULT_POLICY: &[u8] = include_bytes!("../policies/policy.wasm");
/// Returns an in-memory async reader over the embedded default policy bytes.
///
/// The returned reader can be handed to [`PolicyFactory::load`] as its
/// `source` argument.
pub fn default_wasm_policy() -> impl AsyncRead + Unpin {
    let embedded: &'static [u8] = DEFAULT_POLICY;
    Cursor::new(embedded)
}
#[derive(Debug, Error)]
pub enum LoadError {
#[error("failed to read module")]
Read(#[from] tokio::io::Error),
#[error("failed to create WASM engine")]
Engine(#[source] anyhow::Error),
#[error("module compilation task crashed")]
CompilationTask(#[from] tokio::task::JoinError),
#[error("failed to compile WASM module")]
Compilation(#[source] anyhow::Error),
#[error("failed to instantiate a test instance")]
Instantiate(#[source] anyhow::Error),
}
/// Holds a compiled policy module together with everything needed to create
/// [`Policy`] instances from it.
pub struct PolicyFactory {
    /// Engine the module was compiled with; each new store is created from it.
    engine: Engine,

    /// The compiled policy module.
    module: Module,

    /// JSON data handed to every instance created by this factory.
    data: serde_json::Value,

    /// Name of the entrypoint evaluated by [`Policy::evaluate_login`].
    login_entrypoint: String,

    /// Name of the entrypoint evaluated by [`Policy::evaluate_register`].
    register_entrypoint: String,

    /// Name of the entrypoint evaluated by
    /// [`Policy::evaluate_client_registration`].
    client_registration_entrypoint: String,
}
impl PolicyFactory {
    /// Reads a WASM policy module from `source`, compiles it and returns a
    /// factory able to create [`Policy`] instances from it.
    ///
    /// `data` is the JSON document given to every instance created through
    /// [`Self::instantiate`]. The three `*_entrypoint` arguments name the
    /// entrypoints evaluated for login, registration and client registration.
    ///
    /// Before returning, one instance is created and immediately dropped so
    /// that an unusable module or a missing entrypoint is reported here rather
    /// than on first use.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the engine cannot be created, the source
    /// cannot be read, the compilation task cannot be joined, the module does
    /// not compile, or the test instantiation fails.
    pub async fn load(
        mut source: impl AsyncRead + std::marker::Unpin,
        data: serde_json::Value,
        login_entrypoint: String,
        register_entrypoint: String,
        client_registration_entrypoint: String,
    ) -> Result<Self, LoadError> {
        let mut config = Config::default();
        config.async_support(true);
        config.cranelift_opt_level(wasmtime::OptLevel::Speed);

        let engine = Engine::new(&config).map_err(LoadError::Engine)?;

        // Read the whole module into memory before compiling it
        let mut buf = Vec::new();
        source.read_to_end(&mut buf).await?;

        // Compilation is CPU-bound, so spawn that in a blocking task. The
        // engine is moved into the task and handed back with the module.
        let (engine, module) = tokio::task::spawn_blocking(move || {
            let module = Module::new(&engine, buf)?;
            anyhow::Ok((engine, module))
        })
        .await?
        .map_err(LoadError::Compilation)?;

        let factory = Self {
            engine,
            module,
            data,
            login_entrypoint,
            register_entrypoint,
            client_registration_entrypoint,
        };

        // Try to instantiate once; the resulting instance is discarded
        factory
            .instantiate()
            .await
            .map_err(LoadError::Instantiate)?;

        Ok(factory)
    }

    /// Creates a fresh [`Policy`] instance backed by its own store.
    ///
    /// # Errors
    ///
    /// Fails if the runtime cannot be created for the module, if any of the
    /// three configured entrypoints is not exported by the module, or if the
    /// data cannot be loaded into the instance.
    pub async fn instantiate(&self) -> Result<Policy, anyhow::Error> {
        let mut store = Store::new(&self.engine, ());
        let runtime = Runtime::new(&mut store, &self.module).await?;

        // Check that we have *all* the required entrypoints, including the
        // client registration one, so a misconfigured policy is rejected here
        // instead of failing later during evaluation.
        let entrypoints = runtime.entrypoints();
        for e in [
            self.login_entrypoint.as_str(),
            self.register_entrypoint.as_str(),
            self.client_registration_entrypoint.as_str(),
        ] {
            if !entrypoints.contains(e) {
                bail!("missing entrypoint {e}")
            }
        }

        let instance = runtime.with_data(&mut store, &self.data).await?;

        Ok(Policy {
            store,
            instance,
            login_entrypoint: self.login_entrypoint.clone(),
            register_entrypoint: self.register_entrypoint.clone(),
            client_registration_entrypoint: self.client_registration_entrypoint.clone(),
        })
    }
}
/// One element of the array produced by evaluating an entrypoint.
#[derive(Deserialize)]
struct EvaluationResult {
    /// Boolean decoded from the `result` key of the evaluation output.
    result: bool,
}
/// A policy instance created by [`PolicyFactory::instantiate`], with its own
/// store and a copy of the configured entrypoint names.
#[derive(Debug)]
pub struct Policy {
    /// Store passed mutably to every evaluation of `instance`.
    store: Store<()>,

    /// The policy instance being evaluated.
    instance: opa_wasm::Policy,

    /// Name of the entrypoint used by [`Policy::evaluate_login`].
    login_entrypoint: String,

    /// Name of the entrypoint used by [`Policy::evaluate_register`].
    register_entrypoint: String,

    /// Name of the entrypoint used by
    /// [`Policy::evaluate_client_registration`].
    client_registration_entrypoint: String,
}
impl Policy {
#[tracing::instrument]
pub async fn evaluate_login(
&mut self,
user: &mas_data_model::User<()>,
) -> Result<bool, anyhow::Error> {
let user = serde_json::to_value(user)?;
let input = serde_json::json!({ "user": user });
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.login_entrypoint, &input)
.await?;
Ok(res.result)
}
#[tracing::instrument]
pub async fn evaluate_register(
&mut self,
username: &str,
email: &str,
) -> Result<bool, anyhow::Error> {
let input = serde_json::json!({
"user": {
"username": username,
"email": email
}
});
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.register_entrypoint, &input)
.await?;
Ok(res.result)
}
#[tracing::instrument]
pub async fn evaluate_client_registration(
&mut self,
client_metadata: &ClientMetadata,
) -> Result<bool, anyhow::Error> {
let client_metadata = serde_json::to_value(client_metadata)?;
let input = serde_json::json!({
"client_metadata": client_metadata,
});
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(
&mut self.store,
&self.client_registration_entrypoint,
&input,
)
.await?;
Ok(res.result)
}
}