1
0
Fork 0

Did I tell you I was paranoid

This commit is contained in:
Honbra 2024-04-16 22:07:00 +02:00
parent 4de6254f08
commit 0ec4d86221
Signed by: honbra
GPG key ID: B61CC9ADABE2D952
23 changed files with 388 additions and 299 deletions

View file

@ -1,50 +1,51 @@
use std::{path::PathBuf, sync::Arc};
use std::path::PathBuf;
use axum::{body::Body, extract::State, routing::post, Json, Router};
use axum::{
body::Body,
extract::{Path, State},
Json,
};
use axum_extra::{routing::Resource, TypedHeader};
use futures_util::TryStreamExt;
use headers::ContentType;
use mime::Mime;
use serde::Serialize;
use sha2::{Digest, Sha256};
use sqlx::{query, PgPool};
use tokio::{
fs::{self, File},
io,
};
use sqlx::query;
use tokio::{fs, io};
use tokio_util::io::StreamReader;
use tracing::{error, field, info, instrument};
use ulid::Ulid;
use uuid::Uuid;
use crate::{config::Config, error::AppError};
use crate::{app::SharedState, error::AppError};
#[derive(Clone)]
struct SharedState {
db: PgPool,
config: Arc<Config>,
pub fn resource() -> Resource<SharedState> {
Resource::named("files")
.create(upload_file)
.show(get_file_info)
}
pub fn router(db: PgPool, config: Arc<Config>) -> Router {
Router::new()
.route("/", post(upload_file))
.with_state(SharedState { db, config })
}
#[derive(Debug, Serialize)]
struct UploadedFile {
key: Ulid,
#[derive(Serialize)]
struct File {
id: Ulid,
hash: String,
mime: String,
keys: Vec<Ulid>,
}
#[instrument(skip(db, body))]
async fn upload_file(
State(SharedState { db, config }): State<SharedState>,
TypedHeader(content_type): TypedHeader<ContentType>,
body: Body,
) -> Result<Json<UploadedFile>, AppError> {
let id_temp = Ulid::new();
let file_path_temp = PathBuf::from("temp").join(id_temp.to_string());
) -> Result<Json<File>, AppError> {
let id = Ulid::new();
let path_temp = config.file_temp_dir.join(id.to_string());
let mut hasher = Sha256::new();
{
let mut file_temp = File::create(&file_path_temp).await?;
let mut file_temp = fs::File::create(&path_temp).await?;
let better_body = body
.into_data_stream()
@ -55,12 +56,12 @@ async fn upload_file(
if let Err(err) = io::copy(&mut reader, &mut file_temp).await {
error!(
err = field::display(&err),
file_path = field::debug(&file_path_temp),
file_path = field::debug(&path_temp),
"failed to copy file, removing",
);
drop(file_temp);
if let Err(err) = fs::remove_file(file_path_temp).await {
if let Err(err) = fs::remove_file(path_temp).await {
error!(
err = field::display(err),
"failed to remove failed upload file",
@ -73,16 +74,16 @@ async fn upload_file(
let hash = hasher.finalize();
let hash_hex = hex::encode(hash);
let file_path_hash = PathBuf::from("files").join(&hash_hex);
let path_hash = PathBuf::from("files").join(&hash_hex);
if fs::try_exists(&file_path_hash).await? {
if fs::try_exists(&path_hash).await? {
info!(hash = hash_hex, "file already exists");
if let Err(err) = fs::remove_file(&file_path_temp).await {
if let Err(err) = fs::remove_file(&path_temp).await {
error!(err = field::display(&err), "failed to remove temp file");
}
} else if let Err(err) = fs::rename(&file_path_temp, &file_path_hash).await {
} else if let Err(err) = fs::rename(&path_temp, &path_hash).await {
error!(err = field::display(&err), "failed to move finished file");
if let Err(err) = fs::remove_file(&file_path_temp).await {
if let Err(err) = fs::remove_file(&path_temp).await {
error!(
err = field::display(&err),
"failed to remove file after failed move",
@ -91,27 +92,64 @@ async fn upload_file(
return Err(err.into());
}
let key = Ulid::new();
query!(
"INSERT INTO file (hash, mime) VALUES ($1, $2) ON CONFLICT DO NOTHING",
&hash[..],
"video/mp4", // I was testing with a video lol
)
.execute(&db)
.await?;
let result = query!(
"INSERT INTO file_key (id, file_hash) VALUES ($1, $2)",
Uuid::from(key),
&hash[..],
)
.execute(&db)
.await?;
let mime = Into::<Mime>::into(content_type);
let mime_str = mime.to_string();
match result.rows_affected() {
1 => Ok(Json(UploadedFile {
key,
match query!(
"INSERT INTO file (id, hash, mime) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING",
Uuid::from(id),
&hash[..],
mime_str,
)
.execute(&db)
.await?
.rows_affected()
{
0 | 1 => {}
rows => return Err(AppError::ImpossibleAffectedRows(rows)),
}
let key = Ulid::new();
match query!(
"INSERT INTO file_key (id, file_id) VALUES ($1, $2)",
Uuid::from(key),
Uuid::from(id),
)
.execute(&db)
.await?
.rows_affected()
{
1 => Ok(Json(File {
id,
hash: hash_hex,
mime: mime_str,
keys: vec![key],
})),
rows => Err(AppError::ImpossibleAffectedRows(rows)),
}
}
async fn get_file_info(
State(SharedState { db, .. }): State<SharedState>,
Path(id): Path<Ulid>,
) -> Result<Json<File>, AppError> {
let (file, keys) = tokio::try_join!(
query!(
"SELECT id, hash, mime FROM file WHERE id = $1",
Uuid::from(id),
)
.fetch_optional(&db),
query!("SELECT id FROM file_key WHERE file_id = $1", Uuid::from(id)).fetch_all(&db),
)?;
match file {
Some(r) => Ok(Json(File {
id,
hash: hex::encode(r.hash),
mime: r.mime,
keys: keys.into_iter().map(|r| r.id.into()).collect(),
})),
None => Err(AppError::FileNotFoundId(id)),
}
}

View file

@ -1,25 +1,23 @@
use axum::{
extract::{Path, State},
Json, Router,
Json,
};
use axum_extra::routing::Resource;
use http::StatusCode;
use serde::{Deserialize, Serialize};
use sqlx::{query, PgPool};
use sqlx::query;
use ulid::Ulid;
use url::Url;
use uuid::Uuid;
use crate::error::AppError;
use crate::{app::SharedState, error::AppError};
pub fn router(db: PgPool) -> Router {
let links = Resource::named("links")
pub fn resource() -> Resource<SharedState> {
Resource::named("links")
.create(create_link)
.show(get_link_info)
.update(update_link)
.destroy(delete_link);
Router::new().merge(links).with_state(db)
.destroy(delete_link)
}
#[derive(Serialize)]
@ -30,21 +28,23 @@ struct Link {
}
async fn get_link_info(
State(db): State<PgPool>,
State(SharedState { db, .. }): State<SharedState>,
Path(id): Path<Ulid>,
) -> Result<Json<Link>, AppError> {
let link = query!(
match query!(
"SELECT id, slug, destination FROM link WHERE id = $1",
Uuid::from(id),
)
.fetch_one(&db)
.await?;
Ok(Json(Link {
id: Ulid::from(link.id),
slug: link.slug,
destination: link.destination,
}))
.fetch_optional(&db)
.await?
{
Some(r) => Ok(Json(Link {
id: Ulid::from(r.id),
slug: r.slug,
destination: r.destination,
})),
None => Err(AppError::LinkNotFoundId(id)),
}
}
#[derive(Deserialize)]
@ -54,27 +54,27 @@ struct CreateLinkRequestBody {
}
async fn create_link(
State(db): State<PgPool>,
State(SharedState { db, .. }): State<SharedState>,
Json(CreateLinkRequestBody { slug, destination }): Json<CreateLinkRequestBody>,
) -> Result<Json<Link>, AppError> {
let id = Ulid::new();
let result = query!(
match query!(
"INSERT INTO link (id, slug, destination) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING",
Uuid::from(id),
slug,
destination.to_string(),
)
.execute(&db)
.await?;
match result.rows_affected() {
.await?
.rows_affected()
{
1 => Ok(Json(Link {
id,
slug,
destination: destination.to_string(),
})),
0 => Err(AppError::ApiLinkExists(id)),
0 => Err(AppError::LinkExists(id)),
rows => Err(AppError::ImpossibleAffectedRows(rows)),
}
}
@ -85,36 +85,36 @@ struct UpdateLinkRequestBody {
}
async fn update_link(
State(db): State<PgPool>,
State(SharedState { db, .. }): State<SharedState>,
Path(id): Path<Ulid>,
Json(UpdateLinkRequestBody { destination }): Json<UpdateLinkRequestBody>,
) -> Result<StatusCode, AppError> {
let result = query!(
match query!(
"UPDATE link SET destination = $2 WHERE id = $1",
Uuid::from(id),
destination.to_string(),
)
.execute(&db)
.await?;
match result.rows_affected() {
.await?
.rows_affected()
{
1 => Ok(StatusCode::NO_CONTENT),
0 => Err(AppError::ApiLinkNotFound(id)),
0 => Err(AppError::LinkNotFoundId(id)),
rows => Err(AppError::ImpossibleAffectedRows(rows)),
}
}
async fn delete_link(
State(db): State<PgPool>,
State(SharedState { db, .. }): State<SharedState>,
Path(id): Path<Ulid>,
) -> Result<StatusCode, AppError> {
let result = query!("DELETE FROM link WHERE id = $1", Uuid::from(id))
match query!("DELETE FROM link WHERE id = $1", Uuid::from(id))
.execute(&db)
.await?;
match result.rows_affected() {
.await?
.rows_affected()
{
1 => Ok(StatusCode::NO_CONTENT),
0 => Err(AppError::ApiLinkNotFound(id)),
0 => Err(AppError::LinkNotFoundId(id)),
rows => Err(AppError::ImpossibleAffectedRows(rows)),
}
}

View file

@ -1,15 +1,12 @@
mod files;
mod links;
use std::sync::Arc;
use axum::Router;
use sqlx::PgPool;
use crate::config::Config;
use super::SharedState;
pub fn router(db: PgPool, config: Arc<Config>) -> Router {
pub fn router() -> Router<SharedState> {
Router::new()
.nest("/files", files::router(db.clone(), config))
.nest("/links", links::router(db))
.merge(files::resource())
.merge(links::resource())
}

View file

@ -11,6 +11,12 @@ use tracing::{field, span, Level};
use crate::config::Config;
#[derive(Clone)]
struct SharedState {
db: PgPool,
config: Arc<Config>,
}
pub async fn build_app(config: Config) -> eyre::Result<Router> {
let db = PgPool::connect_with(
PgConnectOptions::new()
@ -21,20 +27,22 @@ pub async fn build_app(config: Config) -> eyre::Result<Router> {
)
.await?;
let config = Arc::new(config);
Ok(root::router(db.clone(), config.clone())
.nest("/api", api::router(db, config))
Ok(root::router()
.nest("/api", api::router())
.with_state(SharedState {
db,
config: Arc::new(config),
})
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<Body>| {
span!(
Level::INFO,
Level::DEBUG,
"http-request",
uri = field::display(request.uri()),
)
})
.on_request(DefaultOnRequest::new().level(Level::DEBUG))
.on_response(DefaultOnResponse::new().level(Level::INFO)),
.on_request(DefaultOnRequest::new())
.on_response(DefaultOnResponse::new()),
))
}

View file

@ -1,5 +1,3 @@
use std::sync::Arc;
use axum::{
body::Body,
extract::{Path, State},
@ -11,83 +9,52 @@ use bytes::Bytes;
use http::{Request, Response};
use http_body_util::{combinators::UnsyncBoxBody, BodyExt};
use mime::Mime;
use sqlx::{query, PgPool};
use sqlx::query;
use tower_http::services::ServeFile;
use tracing::{error, field, instrument};
use ulid::Ulid;
use uuid::Uuid;
use crate::{config::Config, error::AppError};
use super::SharedState;
use crate::error::AppError;
#[derive(Clone)]
struct SharedState {
db: PgPool,
config: Arc<Config>,
}
pub fn router(db: PgPool, config: Arc<Config>) -> Router {
pub fn router() -> Router<SharedState> {
Router::new()
.route("/:slug", get(redirect_link))
.route("/f/:key", get(redirect_file))
.with_state(SharedState { db, config })
.route("/f/:key", get(download_file))
}
async fn redirect_link(
State(SharedState { db, .. }): State<SharedState>,
Path(slug): Path<String>,
) -> Result<Redirect, AppError> {
let result = query!("SELECT id, destination FROM link WHERE slug = $1", slug)
match query!("SELECT destination FROM link WHERE slug = $1", slug)
.fetch_optional(&db)
.await?
.map(|r| (Ulid::from(r.id), r.destination));
match result {
Some((id, destination)) => {
tokio::spawn(increase_visit_count(id, db));
Ok(Redirect::temporary(&destination))
}
None => Err(AppError::LinkNotFound(slug)),
{
Some(r) => Ok(Redirect::temporary(&r.destination)),
None => Err(AppError::LinkNotFoundSlug(slug)),
}
}
#[instrument(skip(db))]
async fn increase_visit_count(id: Ulid, db: PgPool) {
let result = query!(
"UPDATE link SET visit_count = visit_count + 1 WHERE id = $1",
Uuid::from(id),
)
.execute(&db)
.await;
match result {
Ok(result) if result.rows_affected() != 1 => {
error!(err = field::display(AppError::ImpossibleAffectedRows(result.rows_affected())));
}
Err(err) => error!(err = field::display(err)),
_ => {}
}
}
async fn redirect_file(
async fn download_file(
State(SharedState { db, config }): State<SharedState>,
Path(key): Path<Ulid>,
request: Request<Body>,
) -> Result<Response<UnsyncBoxBody<Bytes, BoxError>>, AppError> {
let result = query!(
"SELECT file_hash, mime FROM file_key JOIN file ON file_hash = hash WHERE id = $1",
Uuid::from(key)
match query!(
"SELECT hash, mime FROM file_key JOIN file ON file_id = file.id WHERE file_key.id = $1",
Uuid::from(key),
)
.fetch_optional(&db)
.await?
.map(|r| (r.file_hash, r.mime));
match result {
Some((file_hash, mime)) => {
let mime: Option<Mime> = mime.map_or(None, |m| m.parse().ok());
let file_path = config.file_store_dir.join(hex::encode(file_hash));
.map(|r| (r.hash, r.mime))
{
Some((hash, mime)) => {
let mime: Option<Mime> = mime.parse().ok();
let path = config.file_store_dir.join(hex::encode(hash));
let mut sf = match mime {
Some(mime) => ServeFile::new_with_mime(file_path, &mime),
None => ServeFile::new(file_path),
Some(mime) => ServeFile::new_with_mime(path, &mime),
None => ServeFile::new(path),
};
match sf.try_call(request).await {
Ok(response) => Ok(response.map(|body| body.map_err(Into::into).boxed_unsync())),

View file

@ -1,3 +1,5 @@
use std::path::PathBuf;
use axum::{body::Body, response::IntoResponse};
use http::StatusCode;
use tracing::{error, field};
@ -6,13 +8,17 @@ use ulid::Ulid;
#[derive(Debug, thiserror::Error)]
pub enum AppError {
#[error("link already exists ({0})")]
ApiLinkExists(Ulid),
LinkExists(Ulid),
#[error("link not found ({0})")]
ApiLinkNotFound(Ulid),
LinkNotFoundId(Ulid),
#[error("link not found ({0})")]
LinkNotFound(String),
LinkNotFoundSlug(String),
#[error("file not found ({0})")]
FileNotFoundId(Ulid),
#[error("file key not found ({0})")]
FileKeyNotFound(Ulid),
#[error("file is missing ({0})")]
FileMissing(PathBuf),
#[error("database returned an impossible number of affected rows ({0})")]
ImpossibleAffectedRows(u64),
#[error("database error")]
@ -27,11 +33,13 @@ impl IntoResponse for AppError {
fn into_response(self) -> axum::http::Response<Body> {
error!(err = field::display(&self));
match self {
Self::ApiLinkExists(_) => (StatusCode::BAD_REQUEST, "Link already exists"),
Self::ApiLinkNotFound(_) | Self::LinkNotFound(_) => {
Self::LinkExists(_) => (StatusCode::BAD_REQUEST, "Link already exists"),
Self::LinkNotFoundId(_) | Self::LinkNotFoundSlug(_) => {
(StatusCode::NOT_FOUND, "Link not found")
}
Self::FileNotFoundId(_) => (StatusCode::NOT_FOUND, "File not found"),
Self::FileKeyNotFound(_) => (StatusCode::NOT_FOUND, "File key not found"),
Self::FileMissing(_) => (StatusCode::INTERNAL_SERVER_ERROR, "File is missing"),
Self::ImpossibleAffectedRows(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"Database returned an impossible number of affected rows",

View file

@ -36,20 +36,20 @@ fn main() -> eyre::Result<()> {
.extract()
.context("failed to parse config")?;
let rt = Runtime::new().context("failed to create tokio runtime")?;
Runtime::new()
.context("failed to create tokio runtime")?
.block_on(async move {
let listen_addr = config.listen_addr;
rt.block_on(async move {
let listen_addr = config.listen_addr;
let app = build_app(config).await.context("failed to build app")?;
let listener = TcpListener::bind(&listen_addr)
.await
.context("failed to bind listener")?;
let app = build_app(config).await.context("failed to build app")?;
let listener = TcpListener::bind(&listen_addr)
.await
.context("failed to bind listener")?;
axum::serve(listener, app)
.await
.context("server encountered a runtime error")?;
axum::serve(listener, app)
.await
.context("server encountered a runtime error")?;
Ok(())
})
Ok(())
})
}