diff options
author | Timo Kösters <timo@koesters.xyz> | 2022-12-21 15:51:47 +0100 |
---|---|---|
committer | Timo Kösters <timo@koesters.xyz> | 2022-12-21 15:51:47 +0100 |
commit | 53f14a2c4c216b529cc63137d8704573197aed19 (patch) | |
tree | 866d747ed9345f04e7f31c10f7f4e045af29be2b | |
parent | df16012661b0ec958be41f5a52b4f9bb931dc3ff (diff) | |
parent | d20f21ae32061ef81feb9786ca78fbb1345c9a3a (diff) | |
download | conduit-53f14a2c4c216b529cc63137d8704573197aed19.zip |
Merge remote-tracking branch 'origin/next'
174 files changed, 20057 insertions, 15799 deletions
diff --git a/.dockerignore b/.dockerignore index 933b380..c78ddba 100644 --- a/.dockerignore +++ b/.dockerignore @@ -25,4 +25,4 @@ docker-compose* rustfmt.toml # Documentation -*.md +#*.md @@ -0,0 +1 @@ +use flake @@ -31,7 +31,6 @@ modules.xml ### vscode ### .vscode/* -!.vscode/settings.json !.vscode/tasks.json !.vscode/launch.json !.vscode/extensions.json @@ -62,3 +61,9 @@ conduit.db # Etc. **/*.rs.bk + +# Nix artifacts +/result* + +# Direnv cache +/.direnv diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index eb7a96f..91258ea 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -5,140 +5,10 @@ stages: - upload artifacts variables: + # Make GitLab CI go fast: GIT_SUBMODULE_STRATEGY: recursive FF_USE_FASTZIP: 1 CACHE_COMPRESSION_LEVEL: fastest - # Docker in Docker - DOCKER_HOST: tcp://docker:2375/ - DOCKER_TLS_CERTDIR: "" - DOCKER_DRIVER: overlay2 - -# --------------------------------------------------------------------- # -# Cargo: Compiling for different architectures # -# --------------------------------------------------------------------- # - -.build-cargo-shared-settings: - stage: "build" - needs: [] - rules: - - if: '$CI_COMMIT_BRANCH == "master"' - - if: '$CI_COMMIT_BRANCH == "next"' - - if: "$CI_COMMIT_TAG" - - if: '($CI_MERGE_REQUEST_APPROVED == "true") || $BUILD_EVERYTHING' # Once MR is approved, test all builds. Or if BUILD_EVERYTHING is set. - interruptible: true - image: "registry.gitlab.com/jfowl/conduit-containers/rust-with-tools@sha256:69ab327974aef4cc0daf4273579253bf7ae5e379a6c52729b83137e4caa9d093" - tags: ["docker"] - services: ["docker:dind"] - variables: - SHARED_PATH: $CI_PROJECT_DIR/shared - CARGO_PROFILE_RELEASE_LTO: "true" - CARGO_PROFILE_RELEASE_CODEGEN_UNITS: "1" - CARGO_INCREMENTAL: "false" # https://matklad.github.io/2021/09/04/fast-rust-builds.html#ci-workflow - before_script: - - 'echo "Building for target $TARGET"' - - "rustup show && rustc --version && cargo --version" # Print version info for debugging - # fix cargo and rustup mounts from this container (https://gitlab.com/gitlab-org/gitlab-foss/-/issues/41227) - - "mkdir -p $SHARED_PATH/cargo" - - "cp -r $CARGO_HOME/bin $SHARED_PATH/cargo" - - "cp -r $RUSTUP_HOME $SHARED_PATH" - - "export CARGO_HOME=$SHARED_PATH/cargo RUSTUP_HOME=$SHARED_PATH/rustup" - # If provided, bring in caching through sccache, which uses an external S3 endpoint to store compilation results. - - if [ -n "${SCCACHE_ENDPOINT}" ]; then export RUSTC_WRAPPER=/sccache; fi - script: - # cross-compile conduit for target - - 'time cross build --target="$TARGET" --locked --release' - - 'mv "target/$TARGET/release/conduit" "conduit-$TARGET"' - # print information about linking for debugging - - "file conduit-$TARGET" # print file information - - 'readelf --dynamic conduit-$TARGET | sed -e "/NEEDED/q1"' # ensure statically linked - cache: - # https://doc.rust-lang.org/cargo/guide/cargo-home.html#caching-the-cargo-home-in-ci - key: "cargo-cache-$TARGET" - paths: - - $SHARED_PATH/cargo/registry/index - - $SHARED_PATH/cargo/registry/cache - - $SHARED_PATH/cargo/git/db - artifacts: - expire_in: never - -build:release:cargo:x86_64-unknown-linux-musl-with-debug: - extends: .build-cargo-shared-settings - variables: - CARGO_PROFILE_RELEASE_DEBUG: 2 # Enable debug info for flamegraph profiling - TARGET: "x86_64-unknown-linux-musl" - after_script: - - "mv ./conduit-x86_64-unknown-linux-musl ./conduit-x86_64-unknown-linux-musl-with-debug" - artifacts: - name: "conduit-x86_64-unknown-linux-musl-with-debug" - paths: - - "conduit-x86_64-unknown-linux-musl-with-debug" - expose_as: "Conduit for x86_64-unknown-linux-musl-with-debug" - -build:release:cargo:x86_64-unknown-linux-musl: - extends: .build-cargo-shared-settings - variables: - TARGET: "x86_64-unknown-linux-musl" - artifacts: - name: "conduit-x86_64-unknown-linux-musl" - paths: - - "conduit-x86_64-unknown-linux-musl" - expose_as: "Conduit for x86_64-unknown-linux-musl" - -build:release:cargo:arm-unknown-linux-musleabihf: - extends: .build-cargo-shared-settings - variables: - TARGET: "arm-unknown-linux-musleabihf" - artifacts: - name: "conduit-arm-unknown-linux-musleabihf" - paths: - - "conduit-arm-unknown-linux-musleabihf" - expose_as: "Conduit for arm-unknown-linux-musleabihf" - -build:release:cargo:armv7-unknown-linux-musleabihf: - extends: .build-cargo-shared-settings - variables: - TARGET: "armv7-unknown-linux-musleabihf" - artifacts: - name: "conduit-armv7-unknown-linux-musleabihf" - paths: - - "conduit-armv7-unknown-linux-musleabihf" - expose_as: "Conduit for armv7-unknown-linux-musleabihf" - -build:release:cargo:aarch64-unknown-linux-musl: - extends: .build-cargo-shared-settings - variables: - TARGET: "aarch64-unknown-linux-musl" - artifacts: - name: "conduit-aarch64-unknown-linux-musl" - paths: - - "conduit-aarch64-unknown-linux-musl" - expose_as: "Conduit for aarch64-unknown-linux-musl" - -.cargo-debug-shared-settings: - extends: ".build-cargo-shared-settings" - rules: - - when: "always" - cache: - key: "build_cache--$TARGET--$CI_COMMIT_BRANCH--debug" - script: - # cross-compile conduit for target - - 'time time cross build --target="$TARGET" --locked' - - 'mv "target/$TARGET/debug/conduit" "conduit-debug-$TARGET"' - # print information about linking for debugging - - "file conduit-debug-$TARGET" # print file information - - 'readelf --dynamic conduit-debug-$TARGET | sed -e "/NEEDED/q1"' # ensure statically linked - artifacts: - expire_in: 4 weeks - -build:debug:cargo:x86_64-unknown-linux-musl: - extends: ".cargo-debug-shared-settings" - variables: - TARGET: "x86_64-unknown-linux-musl" - artifacts: - name: "conduit-debug-x86_64-unknown-linux-musl" - paths: - - "conduit-debug-x86_64-unknown-linux-musl" - expose_as: "Conduit DEBUG for x86_64-unknown-linux-musl" # --------------------------------------------------------------------- # # Create and publish docker image # @@ -146,98 +16,106 @@ build:debug:cargo:x86_64-unknown-linux-musl: .docker-shared-settings: stage: "build docker image" - image: jdrouet/docker-with-buildx:stable + image: jdrouet/docker-with-buildx:20.10.21-0.9.1 + needs: [] tags: ["docker"] + variables: + # Docker in Docker: + DOCKER_HOST: tcp://docker:2375/ + DOCKER_TLS_CERTDIR: "" + DOCKER_DRIVER: overlay2 services: - docker:dind - needs: - - "build:release:cargo:x86_64-unknown-linux-musl" - - "build:release:cargo:arm-unknown-linux-musleabihf" - - "build:release:cargo:armv7-unknown-linux-musleabihf" - - "build:release:cargo:aarch64-unknown-linux-musl" - variables: - PLATFORMS: "linux/arm/v6,linux/arm/v7,linux/arm64,linux/amd64" - DOCKER_FILE: "docker/ci-binaries-packaging.Dockerfile" - cache: - paths: - - docker_cache - key: "$CI_JOB_NAME" - before_script: - - docker login -u "$CI_REGISTRY_USER" -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY - # Only log in to Dockerhub if the credentials are given: - - if [ -n "${DOCKER_HUB}" ]; then docker login -u "$DOCKER_HUB_USER" -p "$DOCKER_HUB_PASSWORD" "$DOCKER_HUB"; fi script: - # Prepare buildx to build multiarch stuff: - - docker context create 'ci-context' - - docker buildx create --name 'multiarch-builder' --use 'ci-context' - # Copy binaries to their docker arch path - - mkdir -p linux/ && mv ./conduit-x86_64-unknown-linux-musl linux/amd64 - - mkdir -p linux/arm/ && mv ./conduit-arm-unknown-linux-musleabihf linux/arm/v6 - - mkdir -p linux/arm/ && mv ./conduit-armv7-unknown-linux-musleabihf linux/arm/v7 - - mv ./conduit-aarch64-unknown-linux-musl linux/arm64 - - 'export CREATED=$(date -u +''%Y-%m-%dT%H:%M:%SZ'') && echo "Docker image creation date: $CREATED"' - # Build and push image: + - apk add openssh-client + - eval $(ssh-agent -s) + - mkdir -p ~/.ssh && chmod 700 ~/.ssh + - printf "Host *\n\tStrictHostKeyChecking no\n\n" >> ~/.ssh/config + - sh .gitlab/setup-buildx-remote-builders.sh + # Authorize against this project's own image registry: + - docker login -u "$CI_REGISTRY_USER" -p "$CI_REGISTRY_PASSWORD" $CI_REGISTRY + # Build multiplatform image and push to temporary tag: - > - docker buildx build + docker buildx build + --platform "linux/arm/v7,linux/arm64,linux/amd64" --pull + --tag "$CI_REGISTRY_IMAGE/temporary-ci-images:$CI_JOB_ID" --push - --cache-from=type=local,src=$CI_PROJECT_DIR/docker_cache - --cache-to=type=local,dest=$CI_PROJECT_DIR/docker_cache - --build-arg CREATED=$CREATED - --build-arg VERSION=$(grep -m1 -o '[0-9].[0-9].[0-9]' Cargo.toml) - --build-arg "GIT_REF=$CI_COMMIT_SHORT_SHA" - --platform "$PLATFORMS" - --tag "$TAG" - --tag "$TAG-alpine" - --tag "$TAG-commit-$CI_COMMIT_SHORT_SHA" - --file "$DOCKER_FILE" . - -docker:next:gitlab: - extends: .docker-shared-settings - rules: - - if: '$CI_COMMIT_BRANCH == "next"' - variables: - TAG: "$CI_REGISTRY_IMAGE/matrix-conduit:next" - -docker:next:dockerhub: - extends: .docker-shared-settings - rules: - - if: '$CI_COMMIT_BRANCH == "next" && $DOCKER_HUB' - variables: - TAG: "$DOCKER_HUB_IMAGE/matrixconduit/matrix-conduit:next" + --file "Dockerfile" . + # Build multiplatform image to deb stage and extract their .deb files: + - > + docker buildx build + --platform "linux/arm/v7,linux/arm64,linux/amd64" + --target "packager-result" + --output="type=local,dest=/tmp/build-output" + --file "Dockerfile" . + # Build multiplatform image to binary stage and extract their binaries: + - > + docker buildx build + --platform "linux/arm/v7,linux/arm64,linux/amd64" + --target "builder-result" + --output="type=local,dest=/tmp/build-output" + --file "Dockerfile" . + # Copy to GitLab container registry: + - > + docker buildx imagetools create + --tag "$CI_REGISTRY_IMAGE/$TAG" + --tag "$CI_REGISTRY_IMAGE/$TAG-bullseye" + --tag "$CI_REGISTRY_IMAGE/$TAG-commit-$CI_COMMIT_SHORT_SHA" + "$CI_REGISTRY_IMAGE/temporary-ci-images:$CI_JOB_ID" + # if DockerHub credentials exist, also copy to dockerhub: + - if [ -n "${DOCKER_HUB}" ]; then docker login -u "$DOCKER_HUB_USER" -p "$DOCKER_HUB_PASSWORD" "$DOCKER_HUB"; fi + - > + if [ -n "${DOCKER_HUB}" ]; then + docker buildx imagetools create + --tag "$DOCKER_HUB_IMAGE/$TAG" + --tag "$DOCKER_HUB_IMAGE/$TAG-bullseye" + --tag "$DOCKER_HUB_IMAGE/$TAG-commit-$CI_COMMIT_SHORT_SHA" + "$CI_REGISTRY_IMAGE/temporary-ci-images:$CI_JOB_ID" + ; fi + - mv /tmp/build-output ./ + artifacts: + paths: + - "./build-output/" -docker:master:gitlab: +docker:next: extends: .docker-shared-settings rules: - - if: '$CI_COMMIT_BRANCH == "master"' + - if: '$BUILD_SERVER_SSH_PRIVATE_KEY && $CI_COMMIT_BRANCH == "next"' variables: - TAG: "$CI_REGISTRY_IMAGE/matrix-conduit:latest" + TAG: "matrix-conduit:next" -docker:master:dockerhub: +docker:master: extends: .docker-shared-settings rules: - - if: '$CI_COMMIT_BRANCH == "master" && $DOCKER_HUB' + - if: '$BUILD_SERVER_SSH_PRIVATE_KEY && $CI_COMMIT_BRANCH == "master"' variables: - TAG: "$DOCKER_HUB_IMAGE/matrixconduit/matrix-conduit:latest" + TAG: "matrix-conduit:latest" -docker:tags:gitlab: +docker:tags: extends: .docker-shared-settings rules: - - if: "$CI_COMMIT_TAG" + - if: "$BUILD_SERVER_SSH_PRIVATE_KEY && $CI_COMMIT_TAG" variables: - TAG: "$CI_REGISTRY_IMAGE/matrix-conduit:$CI_COMMIT_TAG" + TAG: "matrix-conduit:$CI_COMMIT_TAG" -docker:tags:dockerhub: - extends: .docker-shared-settings - rules: - - if: "$CI_COMMIT_TAG && $DOCKER_HUB" - variables: - TAG: "$DOCKER_HUB_IMAGE/matrixconduit/matrix-conduit:$CI_COMMIT_TAG" # --------------------------------------------------------------------- # # Run tests # # --------------------------------------------------------------------- # +cargo check: + stage: test + image: docker.io/rust:1.64.0-bullseye + needs: [] + interruptible: true + before_script: + - "rustup show && rustc --version && cargo --version" # Print version info for debugging + - apt-get update && apt-get -y --no-install-recommends install libclang-dev # dependency for rocksdb + script: + - cargo check + + .test-shared-settings: stage: "test" needs: [] @@ -250,8 +128,7 @@ docker:tags:dockerhub: test:cargo: extends: .test-shared-settings before_script: - # If provided, bring in caching through sccache, which uses an external S3 endpoint to store compilation results: - - if [ -n "${SCCACHE_ENDPOINT}" ]; then export RUSTC_WRAPPER=/usr/local/cargo/bin/sccache; fi + - apt-get update && apt-get -y --no-install-recommends install libclang-dev # dependency for rocksdb script: - rustc --version && cargo --version # Print version info for debugging - "cargo test --color always --workspace --verbose --locked --no-fail-fast -- -Z unstable-options --format json | gitlab-report -p test > $CI_PROJECT_DIR/report.xml" @@ -260,14 +137,12 @@ test:cargo: reports: junit: report.xml - test:clippy: extends: .test-shared-settings allow_failure: true before_script: - rustup component add clippy - # If provided, bring in caching through sccache, which uses an external S3 endpoint to store compilation results: - - if [ -n "${SCCACHE_ENDPOINT}" ]; then export RUSTC_WRAPPER=/usr/local/cargo/bin/sccache; fi + - apt-get update && apt-get -y --no-install-recommends install libclang-dev # dependency for rocksdb script: - rustc --version && cargo --version # Print version info for debugging - "cargo clippy --color always --verbose --message-format=json | gitlab-report -p clippy > $CI_PROJECT_DIR/gl-code-quality-report.json" @@ -294,38 +169,6 @@ test:audit: reports: sast: gl-sast-report.json -test:sytest: - stage: "test" - allow_failure: true - needs: - - "build:debug:cargo:x86_64-unknown-linux-musl" - image: - name: "valkum/sytest-conduit:latest" - entrypoint: [""] - tags: ["docker"] - variables: - PLUGINS: "https://github.com/valkum/sytest_conduit/archive/master.tar.gz" - interruptible: true - before_script: - - "mkdir -p /app" - - "cp ./conduit-debug-x86_64-unknown-linux-musl /app/conduit" - - "chmod +x /app/conduit" - - "rm -rf /src && ln -s $CI_PROJECT_DIR/ /src" - - "mkdir -p /work/server-0/database/ && mkdir -p /work/server-1/database/ && mkdir -p /work/server-2/database/" - - "cd /" - script: - - "SYTEST_EXIT_CODE=0" - - "/bootstrap.sh conduit || SYTEST_EXIT_CODE=1" - - 'perl /sytest/tap-to-junit-xml.pl --puretap --input /logs/results.tap --output $CI_PROJECT_DIR/sytest.xml "Sytest" && cp /logs/results.tap $CI_PROJECT_DIR/results.tap' - - "exit $SYTEST_EXIT_CODE" - artifacts: - when: always - paths: - - "$CI_PROJECT_DIR/sytest.xml" - - "$CI_PROJECT_DIR/results.tap" - reports: - junit: "$CI_PROJECT_DIR/sytest.xml" - test:dockerlint: stage: "test" needs: [] @@ -338,14 +181,12 @@ test:dockerlint: hadolint --no-fail --verbose ./Dockerfile - ./docker/ci-binaries-packaging.Dockerfile # Then output the results into a json for GitLab to pretty-print this in the MR: - > hadolint --format gitlab_codeclimate --failure-threshold error - ./Dockerfile - ./docker/ci-binaries-packaging.Dockerfile > dockerlint.json + ./Dockerfile > dockerlint.json artifacts: when: always reports: @@ -365,28 +206,26 @@ test:dockerlint: # Store binaries as package so they have download urls # # --------------------------------------------------------------------- # -publish:package: - stage: "upload artifacts" - needs: - - "build:release:cargo:x86_64-unknown-linux-musl" - - "build:release:cargo:arm-unknown-linux-musleabihf" - - "build:release:cargo:armv7-unknown-linux-musleabihf" - - "build:release:cargo:aarch64-unknown-linux-musl" - # - "build:cargo-deb:x86_64-unknown-linux-gnu" - rules: - - if: '$CI_COMMIT_BRANCH == "master"' - - if: '$CI_COMMIT_BRANCH == "next"' - - if: "$CI_COMMIT_TAG" - image: curlimages/curl:latest - tags: ["docker"] - variables: - GIT_STRATEGY: "none" # Don't need a clean copy of the code, we just operate on artifacts - script: - - 'BASE_URL="${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/generic/conduit-${CI_COMMIT_REF_SLUG}/build-${CI_PIPELINE_ID}"' - - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file conduit-x86_64-unknown-linux-musl "${BASE_URL}/conduit-x86_64-unknown-linux-musl"' - - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file conduit-arm-unknown-linux-musleabihf "${BASE_URL}/conduit-arm-unknown-linux-musleabihf"' - - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file conduit-armv7-unknown-linux-musleabihf "${BASE_URL}/conduit-armv7-unknown-linux-musleabihf"' - - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file conduit-aarch64-unknown-linux-musl "${BASE_URL}/conduit-aarch64-unknown-linux-musl"' +# DISABLED FOR NOW, NEEDS TO BE FIXED AT A LATER TIME: + +#publish:package: +# stage: "upload artifacts" +# needs: +# - "docker:tags" +# rules: +# - if: "$CI_COMMIT_TAG" +# image: curlimages/curl:latest +# tags: ["docker"] +# variables: +# GIT_STRATEGY: "none" # Don't need a clean copy of the code, we just operate on artifacts +# script: +# - 'BASE_URL="${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/generic/conduit-${CI_COMMIT_REF_SLUG}/build-${CI_PIPELINE_ID}"' +# - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file build-output/linux_amd64/conduit "${BASE_URL}/conduit-x86_64-unknown-linux-gnu"' +# - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file build-output/linux_arm_v7/conduit "${BASE_URL}/conduit-armv7-unknown-linux-gnu"' +# - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file build-output/linux_arm64/conduit "${BASE_URL}/conduit-aarch64-unknown-linux-gnu"' +# - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file build-output/linux_amd64/conduit.deb "${BASE_URL}/conduit-x86_64-unknown-linux-gnu.deb"' +# - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file build-output/linux_arm_v7/conduit.deb "${BASE_URL}/conduit-armv7-unknown-linux-gnu.deb"' +# - 'curl --header "JOB-TOKEN: $CI_JOB_TOKEN" --upload-file build-output/linux_arm64/conduit.deb "${BASE_URL}/conduit-aarch64-unknown-linux-gnu.deb"' # Avoid duplicate pipelines # See: https://docs.gitlab.com/ee/ci/yaml/workflow.html#switch-between-branch-pipelines-and-merge-request-pipelines diff --git a/.gitlab/CODEOWNERS b/.gitlab/CODEOWNERS new file mode 100644 index 0000000..665aaaa --- /dev/null +++ b/.gitlab/CODEOWNERS @@ -0,0 +1,5 @@ +# Nix things +.envrc @CobaltCause +flake.lock @CobaltCause +flake.nix @CobaltCause +nix/ @CobaltCause diff --git a/.gitlab/setup-buildx-remote-builders.sh b/.gitlab/setup-buildx-remote-builders.sh new file mode 100644 index 0000000..29d50dd --- /dev/null +++ b/.gitlab/setup-buildx-remote-builders.sh @@ -0,0 +1,37 @@ +#!/bin/sh +set -eux + +# --------------------------------------------------------------------- # +# # +# Configures docker buildx to use a remote server for arm building. # +# Expects $SSH_PRIVATE_KEY to be a valid ssh ed25519 private key with # +# access to the server $ARM_SERVER_USER@$ARM_SERVER_IP # +# # +# This is expected to only be used in the official CI/CD pipeline! # +# # +# Requirements: openssh-client, docker buildx # +# Inspired by: https://depot.dev/blog/building-arm-containers # +# # +# --------------------------------------------------------------------- # + +cat "$BUILD_SERVER_SSH_PRIVATE_KEY" | ssh-add - + +# Test server connections: +ssh "$ARM_SERVER_USER@$ARM_SERVER_IP" "uname -a" +ssh "$AMD_SERVER_USER@$AMD_SERVER_IP" "uname -a" + +# Connect remote arm64 server for all arm builds: +docker buildx create \ + --name "multi" \ + --driver "docker-container" \ + --platform "linux/arm64,linux/arm/v7" \ + "ssh://$ARM_SERVER_USER@$ARM_SERVER_IP" + +# Connect remote amd64 server for adm64 builds: +docker buildx create --append \ + --name "multi" \ + --driver "docker-container" \ + --platform "linux/amd64" \ + "ssh://$AMD_SERVER_USER@$AMD_SERVER_IP" + +docker buildx use multi diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 95294d4..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "rust-analyzer.procMacro.enable": true, -}
\ No newline at end of file @@ -9,60 +9,45 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] -name = "adler32" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" - -[[package]] name = "ahash" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ - "getrandom 0.2.7", + "getrandom 0.2.8", "once_cell", "version_check", ] [[package]] name = "aho-corasick" -version = "0.7.18" +version = "0.7.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" dependencies = [ "memchr", ] [[package]] name = "alloc-no-stdlib" -version = "2.0.3" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35ef4730490ad1c4eae5c4325b2a95f521d023e5c885853ff7aca0a6a1631db3" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" [[package]] name = "alloc-stdlib" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697ed7edc0f1711de49ce108c541623a0af97c6c60b2f6e2b65229847ac843c2" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" dependencies = [ "alloc-no-stdlib", ] [[package]] -name = "ansi_term" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" -dependencies = [ - "winapi", -] - -[[package]] name = "arc-swap" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5d78ce20460b82d3fa150275ed9d55e21064fc7951177baacf86a145c4a4b1f" +checksum = "983cd8b9d4b02a6dc6ffa557262eb5858a27a0038ffffe21a0f133eaa819a164" [[package]] name = "arrayref" @@ -72,9 +57,9 @@ checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544" [[package]] name = "arrayvec" -version = "0.5.2" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" +checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "assign" @@ -84,9 +69,9 @@ checksum = "5f093eed78becd229346bf859eec0aa4dd7ddde0757287b2b4107a1f09c80002" [[package]] name = "async-compression" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345fd392ab01f746c717b1357165b76f0b67a60192007b234058c9045fdcf695" +checksum = "942c7cd7ae39e91bde4820d74132e9862e62c2f386c3aa90ccf55949f5bad63a" dependencies = [ "brotli", "flate2", @@ -98,9 +83,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.56" +version = "0.1.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96cf8829f67d2eab0b2dfa42c5d0ef737e0724e4a82b01b3e292456202b19716" +checksum = "1e805d94e6b5001b651426cf4cd446b1ab5f319d27bab5c644f61de0a804360c" dependencies = [ "proc-macro2", "quote", @@ -124,9 +109,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.5.8" +version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b4d4f9a5ca8b1ab8de59e663e68c6207059239373ca72980f5be7ab81231f74" +checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" dependencies = [ "async-trait", "axum-core", @@ -156,9 +141,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.2.6" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4d047478b986f14a13edad31a009e2e05cb241f9805d0d75e4cba4e129ad4d" +checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc" dependencies = [ "async-trait", "bytes", @@ -166,13 +151,15 @@ dependencies = [ "http", "http-body", "mime", + "tower-layer", + "tower-service", ] [[package]] name = "axum-server" -version = "0.4.0" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abf18303ef7e23b045301555bf8a0dfbc1444ea1a37b3c81757a32680ace4d7d" +checksum = "8456dab8f11484979a86651da8e619b355ede5d61a160755155f6c344bd18c47" dependencies = [ "arc-swap", "bytes", @@ -182,7 +169,7 @@ dependencies = [ "hyper", "pin-project-lite", "rustls", - "rustls-pemfile 1.0.0", + "rustls-pemfile 1.0.1", "tokio", "tokio-rustls", "tower-service", @@ -190,15 +177,21 @@ dependencies = [ [[package]] name = "base64" -version = "0.12.3" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.13.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" +checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5" + +[[package]] +name = "base64ct" +version = "1.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b645a089122eccb6111b4f81cbc1a49f5900ac4666bb93ac027feaecf15607bf" [[package]] name = "bincode" @@ -236,9 +229,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "blake2b_simd" -version = "0.5.11" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afa748e348ad3be8263be728124b24a24f268266f6f5d58af9d75f6a40b5c587" +checksum = "72936ee4afc7f8f736d1c38383b56480b5497b4617b4a77bdbf1d2ababc76127" dependencies = [ "arrayref", "arrayvec", @@ -256,9 +249,9 @@ dependencies = [ [[package]] name = "block-buffer" -version = "0.10.2" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" dependencies = [ "generic-array", ] @@ -286,15 +279,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.10.0" +version = "3.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37ccbd214614c6783386c1af30caf03192f17891059cecc394b4fb119e363de3" +checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" [[package]] name = "bytemuck" -version = "1.9.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdead85bdec19c194affaeeb670c0e41fe23de31459efd1c174d049269cf02cc" +checksum = "aaa3a8d9a1ca92e282c96a32d6511b695d7d994d1d102ba85d279f9b2756947f" [[package]] name = "byteorder" @@ -304,15 +297,15 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.1.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" +checksum = "dfb24e866b15a1af2a1b663f10c6b6b8f397a84aadb828f12e5b289ec23a3a3c" [[package]] name = "cc" -version = "1.0.73" +version = "1.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +checksum = "e9f73505338f7d905b19d18738976aae232eb46b8efc15554ffc56deb5d9ebe4" dependencies = [ "jobserver", ] @@ -328,34 +321,15 @@ dependencies = [ [[package]] name = "cfg-if" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" - -[[package]] -name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] -name = "chrono" -version = "0.4.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" -dependencies = [ - "libc", - "num-integer", - "num-traits", - "time", - "winapi", -] - -[[package]] name = "clang-sys" -version = "1.3.3" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a050e2153c5be08febd6734e29298e844fdb0fa21aeddd63b4eb7baa106c69b" +checksum = "fa2e27ae6ab525c3d369ded447057bca5438d86dc3a68f6faafb8269ba82ebf3" dependencies = [ "glob", "libc", @@ -364,23 +338,21 @@ dependencies = [ [[package]] name = "clap" -version = "3.2.5" +version = "4.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53da17d37dba964b9b3ecb5c5a1f193a2762c700e6829201e645b9381c99dc7" +checksum = "0acbd8d28a0a60d7108d7ae850af6ba34cf2d1257fc646980e5f97ce14275966" dependencies = [ "bitflags", "clap_derive", "clap_lex", - "indexmap", "once_cell", - "textwrap", ] [[package]] name = "clap_derive" -version = "3.2.5" +version = "4.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c11d40217d16aee8508cc8e5fde8b4ff24639758608e5374e731b53f85749fb9" +checksum = "0177313f9f02afc995627906bbd8967e2be069f5261954222dac78290c2b9014" dependencies = [ "heck", "proc-macro-error", @@ -391,9 +363,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5538cd660450ebeb4234cfecf8f2284b844ffc4c50531e66d584ad5b91293613" +checksum = "0d4198f73e42b4936b35b5bb248d81d2b595ecb170da0bac7655c54eedfa8da8" dependencies = [ "os_str_bytes", ] @@ -406,11 +378,12 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "conduit" -version = "0.4.0" +version = "0.5.0" dependencies = [ + "async-trait", "axum", "axum-server", - "base64 0.13.0", + "base64 0.13.1", "bytes", "clap", "crossbeam", @@ -422,6 +395,7 @@ dependencies = [ "http", "image", "jsonwebtoken", + "lazy_static", "lru-cache", "num_cpus", "opentelemetry", @@ -436,11 +410,11 @@ dependencies = [ "ruma", "rusqlite", "rust-argon2", + "sd-notify", "serde", "serde_json", "serde_yaml", - "sha-1 0.9.8", - "sled", + "sha-1", "thiserror", "thread_local", "threadpool", @@ -451,15 +425,16 @@ dependencies = [ "tower-http", "tracing", "tracing-flame", + "tracing-opentelemetry", "tracing-subscriber", "trust-dns-resolver", ] [[package]] name = "const-oid" -version = "0.6.2" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6f2aa4d0537bcc1c74df8755072bd31c1ef1a3a1b85a68e8404a8c353b7b8b" +checksum = "cec318a675afcb6a1ea1d4340e2d377e56e47c266f28043ceccbf4412ddfdd3b" [[package]] name = "constant_time_eq" @@ -485,9 +460,9 @@ checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" [[package]] name = "cpufeatures" -version = "0.2.2" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" +checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" dependencies = [ "libc", ] @@ -513,122 +488,91 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", ] [[package]] name = "crossbeam" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae5588f6b3c3cb05239e90bd110f257254aecd01e4635400391aeae07497845" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "crossbeam-channel", "crossbeam-deque", "crossbeam-epoch", - "crossbeam-queue 0.3.5", - "crossbeam-utils 0.8.9", + "crossbeam-queue", + "crossbeam-utils", ] [[package]] name = "crossbeam-channel" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c02a4d71819009c192cf4872265391563fd6a84c81ff2c0f2a7026ca4c1d85c" +checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" dependencies = [ - "cfg-if 1.0.0", - "crossbeam-utils 0.8.9", + "cfg-if", + "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e" +checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "crossbeam-epoch", - "crossbeam-utils 0.8.9", + "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" -version = "0.9.9" +version = "0.9.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07db9d94cbd326813772c968ccd25999e5f8ae22f4f8d1b11effa37ef6ce281d" +checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a" dependencies = [ "autocfg", - "cfg-if 1.0.0", - "crossbeam-utils 0.8.9", + "cfg-if", + "crossbeam-utils", "memoffset", - "once_cell", "scopeguard", ] [[package]] name = "crossbeam-queue" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c979cd6cfe72335896575c6b5688da489e420d36a27a0b9eb0c73db574b4a4b" -dependencies = [ - "crossbeam-utils 0.6.6", -] - -[[package]] -name = "crossbeam-queue" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f25d8400f4a7a5778f0e4e52384a48cbd9b5c495d110786187fc750075277a2" -dependencies = [ - "cfg-if 1.0.0", - "crossbeam-utils 0.8.9", -] - -[[package]] -name = "crossbeam-utils" -version = "0.6.6" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04973fa96e96579258a5091af6003abde64af786b860f18622b82e026cca60e6" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" dependencies = [ - "cfg-if 0.1.10", - "lazy_static", + "cfg-if", + "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.9" +version = "0.8.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ff1f980957787286a554052d03c7aee98d99cc32e09f6d45f0a814133c87978" +checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" dependencies = [ - "cfg-if 1.0.0", - "once_cell", + "cfg-if", ] [[package]] name = "crypto-common" -version = "0.1.3" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", "typenum", ] [[package]] -name = "crypto-mac" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1d1a86f49236c215f271d40892d5fc950490551400b02ef360692c29815c714" -dependencies = [ - "generic-array", - "subtle", -] - -[[package]] name = "curve25519-dalek" -version = "3.2.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90f9d052967f590a76e62eb387bd0bbb1b000182c3cefe5364db6b7211651bc0" +checksum = "0b9fdf9972b2bd6af2d913799d9ebc165ea4d2e65878e329d9c6b372c4491b61" dependencies = [ "byteorder", "digest 0.9.0", @@ -638,28 +582,32 @@ dependencies = [ ] [[package]] -name = "data-encoding" -version = "2.3.2" +name = "dashmap" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] [[package]] -name = "deflate" -version = "0.8.6" +name = "data-encoding" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73770f8e1fe7d64df17ca66ad28994a0a623ea497fa69486e14984e715c5d174" -dependencies = [ - "adler32", - "byteorder", -] +checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" [[package]] name = "der" -version = "0.4.5" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79b71cca7d95d7681a4b3b9cdf63c8dbc3730d0584c2c74e31416d64a90493f4" +checksum = "13dd2ae565c0a381dde7fade45fce95984c568bdcb4700a4fdbe3175e0380b2f" dependencies = [ "const-oid", + "zeroize", ] [[package]] @@ -673,12 +621,13 @@ dependencies = [ [[package]] name = "digest" -version = "0.10.3" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" +checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ - "block-buffer 0.10.2", + "block-buffer 0.10.3", "crypto-common", + "subtle", ] [[package]] @@ -726,9 +675,9 @@ dependencies = [ [[package]] name = "either" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" +checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797" [[package]] name = "encoding_rs" @@ -736,14 +685,14 @@ version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9852635589dc9f9ea1b6fe9f05b50ef208c85c834a562f0c6abb1c475736ec2b" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", ] [[package]] name = "enum-as-inner" -version = "0.3.4" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "570d109b813e904becc80d8d5da38376818a143348413f7149f1340fe04754d4" +checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116" dependencies = [ "heck", "proc-macro2", @@ -765,9 +714,9 @@ checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" [[package]] name = "figment" -version = "0.10.6" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "790b4292c72618abbab50f787a477014fe15634f96291de45672ce46afe122df" +checksum = "4e56602b469b2201400dec66a66aec5a9b8761ee97cd1b8c96ab2483fcc16cc9" dependencies = [ "atomic", "pear", @@ -779,12 +728,12 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f82b0f4c27ad9f8bfd1f3208d882da2b09c301bc1c828fd3a00d0216d2fbbff6" +checksum = "a8a2db397cb1c8772f31494cb8917e48cd1e64f0fa7efac59fbd741a0a8ce841" dependencies = [ "crc32fast", - "miniz_oxide 0.5.3", + "miniz_oxide", ] [[package]] @@ -795,11 +744,10 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fc25a87fa4fd2094bffb06925852034d90a17f0d1e05197d4956d3555752191" +checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" dependencies = [ - "matches", "percent-encoding", ] @@ -821,9 +769,9 @@ checksum = "2022715d62ab30faffd124d40b76f4134a550a87792276512b18d63272333394" [[package]] name = "futures" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f73fe65f54d1e12b726f517d3e2135ca3125a437b6d998caf1962961f7172d9e" +checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0" dependencies = [ "futures-channel", "futures-core", @@ -836,9 +784,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3083ce4b914124575708913bca19bfe887522d6e2e6d0952943f5eac4a74010" +checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed" dependencies = [ "futures-core", "futures-sink", @@ -846,15 +794,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" +checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac" [[package]] name = "futures-executor" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9420b90cfa29e327d0429f19be13e7ddb68fa1cccb09d65e5706b8c7a749b8a6" +checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2" dependencies = [ "futures-core", "futures-task", @@ -863,15 +811,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b" +checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" [[package]] name = "futures-macro" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c1e13800337f4d4d7a316bf45a567dbcb6ffe087f16424852d97e97a91f512" +checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d" dependencies = [ "proc-macro2", "quote", @@ -880,21 +828,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" +checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9" [[package]] name = "futures-task" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" +checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea" [[package]] name = "futures-util" -version = "0.3.21" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" +checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6" dependencies = [ "futures-channel", "futures-core", @@ -909,19 +857,10 @@ dependencies = [ ] [[package]] -name = "fxhash" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" -dependencies = [ - "byteorder", -] - -[[package]] name = "generic-array" -version = "0.14.5" +version = "0.14.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" +checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" dependencies = [ "typenum", "version_check", @@ -933,27 +872,27 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "libc", "wasi 0.9.0+wasi-snapshot-preview1", ] [[package]] name = "getrandom" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" +checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "libc", "wasi 0.11.0+wasi-snapshot-preview1", ] [[package]] name = "gif" -version = "0.11.3" +version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a7187e78088aead22ceedeee99779455b23fc231fe13ec443f99bb71694e5b" +checksum = "3edd93c6756b4dfaf2709eafcc345ba2636565295c198a9cfbf75fa5e3e00b06" dependencies = [ "color_quant", "weezl", @@ -967,9 +906,9 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" [[package]] name = "h2" -version = "0.3.13" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37a82c6d637fc9515a4694bbf1cb2457b79d81ce52b3108bdeea58b07dd34a57" +checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" dependencies = [ "bytes", "fnv", @@ -986,42 +925,36 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.11.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" dependencies = [ "ahash", ] [[package]] -name = "hashbrown" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db0d4cf898abf0081f964436dc980e96670a0f36863e4b83aaacdb65c9d7ccc3" - -[[package]] name = "hashlink" -version = "0.7.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" +checksum = "69fe1fcf8b4278d860ad0548329f892a3631fb63f82574df68275f34cdbe0ffa" dependencies = [ - "hashbrown 0.11.2", + "hashbrown", ] [[package]] name = "headers" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cff78e5788be1e0ab65b04d306b2ed5092c815ec97ec70f4ebd5aee158aa55d" +checksum = "f3e372db8e5c0d213e0cd0b9be18be2aca3d44cf2fe30a9d46a65581cd454584" dependencies = [ - "base64 0.13.0", + "base64 0.13.1", "bitflags", "bytes", "headers-core", "http", "httpdate", "mime", - "sha-1 0.10.0", + "sha1", ] [[package]] @@ -1086,12 +1019,11 @@ dependencies = [ [[package]] name = "hmac" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a2a2320eb7ec0ebe8da8f744d7812d9fc4cb4d09344ac01898dbcb6a20ae69b" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" dependencies = [ - "crypto-mac", - "digest 0.9.0", + "digest 0.10.6", ] [[package]] @@ -1135,9 +1067,9 @@ checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" [[package]] name = "httparse" -version = "1.7.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "496ce29bb5a52785b44e0f7ca2847ae0bb839c9bd28f69acac9b99d461c0c04c" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" [[package]] name = "httpdate" @@ -1147,9 +1079,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.19" +version = "0.14.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42dc3c131584288d375f2d07f822b0cb012d8c6fb899a5b9fdb3cb7eb9b6004f" +checksum = "034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c" dependencies = [ "bytes", "futures-channel", @@ -1162,7 +1094,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.4", + "socket2", "tokio", "tower-service", "tracing", @@ -1171,9 +1103,9 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.23.0" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac" +checksum = "59df7c4e19c950e6e0e868dcc0a300b09a9b88e9ec55bd879ca819087a77355d" dependencies = [ "http", "hyper", @@ -1194,17 +1126,26 @@ dependencies = [ ] [[package]] +name = "idna" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] name = "image" -version = "0.23.14" +version = "0.24.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24ffcb7e7244a9bf19d35bf2883b9c080c4ced3c07a9895572178cdb8f13f6a1" +checksum = "69b7ea949b537b0fd0af141fff8c77690f2ce96f4f41f042ccb6c69c6c965945" dependencies = [ "bytemuck", "byteorder", "color_quant", "gif", "jpeg-decoder", - "num-iter", "num-rational", "num-traits", "png", @@ -1212,95 +1153,80 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.0" +version = "1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c6392766afd7964e2531940894cffe4bd8d7d17dbc3c1c4857040fd4b33bdb3" +checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" dependencies = [ "autocfg", - "hashbrown 0.12.1", + "hashbrown", "serde", ] [[package]] -name = "indoc" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05a0bd019339e5d968b37855180087b7b9d512c5046fbd244cf8c95687927d6e" - -[[package]] name = "inlinable_string" version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" [[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if 1.0.0", -] - -[[package]] name = "integer-encoding" -version = "1.1.7" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48dc51180a9b377fd75814d0cc02199c20f8e99433d6762f650d39cdbbd3b56f" +checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipconfig" -version = "0.2.2" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7e2f18aece9709094573a9f24f483c4f65caa4298e2f7ae1b71cc65d853fad7" +checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be" dependencies = [ - "socket2 0.3.19", + "socket2", "widestring", "winapi", - "winreg 0.6.2", + "winreg 0.10.1", ] [[package]] name = "ipnet" -version = "2.5.0" +version = "2.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" +checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745" [[package]] name = "itertools" -version = "0.10.3" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9a9d19fa1e79b6215ff29b9d6880b706147f16e9b1dbb1e4e5947b5b02bc5e3" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" dependencies = [ "either", ] [[package]] name = "itoa" -version = "1.0.2" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112c678d4050afce233f4f2852bb2eb519230b3cf12f33585275537d7e41578d" +checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc" [[package]] name = "jobserver" -version = "0.1.24" +version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af25a77299a7f711a01975c35a6a424eb6862092cc2d6c72c4ed6cbc56dfc1fa" +checksum = "068b1ee6743e4d11fb9c6a1e6064b3693a1b600e7f5f5988047d98b3dc9fb90b" dependencies = [ "libc", ] [[package]] name = "jpeg-decoder" -version = "0.1.22" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "229d53d58899083193af11e15917b5640cd40b29ff475a1fe4ef725deb02d0f2" +checksum = "bc0000e42512c92e31c2252315bda326620a4e034105e900c98ec492fa077b3e" [[package]] name = "js-sys" -version = "0.3.58" +version = "0.3.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3fac17f7123a73ca62df411b1bf727ccc805daa070338fda671c86dac1bdc27" +checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" dependencies = [ "wasm-bindgen", ] @@ -1315,12 +1241,21 @@ dependencies = [ ] [[package]] +name = "js_option" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68421373957a1593a767013698dbf206e2b221eefe97a44d98d18672ff38423c" +dependencies = [ + "serde", +] + +[[package]] name = "jsonwebtoken" -version = "7.2.0" +version = "8.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afabcc15e437a6484fc4f12d0fd63068fe457bf93f1c148d3d9649c60b103f32" +checksum = "1aa4b4af834c6cfd35d8763d359661b90f2e45d8f750a0849156c7f4671af09c" dependencies = [ - "base64 0.12.3", + "base64 0.13.1", "pem", "ring", "serde", @@ -1329,6 +1264,28 @@ dependencies = [ ] [[package]] +name = "konst" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "330f0e13e6483b8c34885f7e6c9f19b1a7bd449c673fbb948a51c99d66ef74f4" +dependencies = [ + "konst_macro_rules", + "konst_proc_macros", +] + +[[package]] +name = "konst_macro_rules" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4933f3f57a8e9d9da04db23fb153356ecaf00cbd14aee46279c33dc80925c37" + +[[package]] +name = "konst_proc_macros" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "984e109462d46ad18314f10e392c286c3d47bce203088a09012de1015b45b737" + +[[package]] name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1342,17 +1299,17 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.126" +version = "0.2.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" +checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" [[package]] name = "libloading" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efbc0f03f9a775e9f6aed295c6a1ba2253c5757a9e03d55c6caa46a681abcddd" +checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "winapi", ] @@ -1370,9 +1327,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.22.2" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290b64917f8b0cb885d9de0f9959fe1f775d7fa12f1da2db9001c1c8ab60f89d" +checksum = "29f835d03d717946d28b1d1ed632eb6f0e24a299388ee623d0c23118d3e8a7fa" dependencies = [ "cc", "pkg-config", @@ -1381,9 +1338,9 @@ dependencies = [ [[package]] name = "linked-hash-map" -version = "0.5.4" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fb9b38af92608140b86b693604b9ffcc5824240a484d1ecd4795bacb2fe88f3" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "lmdb-rkv-sys" @@ -1398,9 +1355,9 @@ dependencies = [ [[package]] name = "lock_api" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53" +checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" dependencies = [ "autocfg", "scopeguard", @@ -1412,7 +1369,7 @@ version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", ] [[package]] @@ -1438,9 +1395,9 @@ checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" [[package]] name = "matchers" -version = "0.0.1" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" dependencies = [ "regex-automata", ] @@ -1465,9 +1422,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "memoffset" -version = "0.6.5" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" dependencies = [ "autocfg", ] @@ -1486,32 +1443,23 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "791daaae1ed6889560f8c4359194f56648355540573244a5448a83ba1ecc7435" -dependencies = [ - "adler32", -] - -[[package]] -name = "miniz_oxide" -version = "0.5.3" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f5c75688da582b8ffc1f1799e9db273f32133c49e048f614d22ec3256773ccc" +checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" dependencies = [ "adler", ] [[package]] name = "mio" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713d550d9b44d89174e066b7a6217ae06234c10cb47819a88290d2b353c31799" +checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -1525,42 +1473,41 @@ dependencies = [ ] [[package]] -name = "num-bigint" -version = "0.2.6" +name = "nu-ansi-term" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" dependencies = [ - "autocfg", - "num-integer", - "num-traits", + "overload", + "winapi", ] [[package]] -name = "num-integer" -version = "0.1.45" +name = "num-bigint" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" dependencies = [ "autocfg", + "num-integer", "num-traits", ] [[package]] -name = "num-iter" -version = "0.1.43" +name = "num-integer" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" dependencies = [ "autocfg", - "num-integer", "num-traits", ] [[package]] name = "num-rational" -version = "0.3.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12ac428b1cb17fce6f731001d307d351ec70a6d202fc2e60f7d4c5e42d8f4f07" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" dependencies = [ "autocfg", "num-integer", @@ -1578,9 +1525,9 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.13.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" +checksum = "f6058e64324c71e02bc2b150e4f3bc8286db6c83092132ffa3f6b1eab0f9def5" dependencies = [ "hermit-abi", "libc", @@ -1588,9 +1535,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.12.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7709cef83f0c1f58f666e746a08b21e0085f7440fa6a29cc194d68aac97a4225" +checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" [[package]] name = "opaque-debug" @@ -1606,31 +1553,24 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "opentelemetry" -version = "0.16.0" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf9b1c4e9a6c4de793c632496fa490bdc0e1eea73f0c91394f7b6990935d22" +checksum = "69d6c3d7288a106c0a363e4b0e8d308058d56902adefb16f4936f417ffef086e" dependencies = [ - "async-trait", - "crossbeam-channel", - "futures", - "js-sys", - "lazy_static", - "percent-encoding", - "pin-project", - "rand 0.8.5", - "thiserror", - "tokio", - "tokio-stream", + "opentelemetry_api", + "opentelemetry_sdk", ] [[package]] name = "opentelemetry-jaeger" -version = "0.15.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db22f492873ea037bc267b35a0e8e4fb846340058cb7c864efe3d0bf23684593" +checksum = "1e785d273968748578931e4dc3b4f5ec86b26e09d9e0d66b55adda7fce742f7a" dependencies = [ "async-trait", - "lazy_static", + "futures", + "futures-executor", + "once_cell", "opentelemetry", "opentelemetry-semantic-conventions", "thiserror", @@ -1640,14 +1580,52 @@ dependencies = [ [[package]] name = "opentelemetry-semantic-conventions" -version = "0.8.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffeac823339e8b0f27b961f4385057bf9f97f2863bc745bd015fd6091f2270e9" +checksum = "9b02e0230abb0ab6636d18e2ba8fa02903ea63772281340ccac18e0af3ec9eeb" dependencies = [ "opentelemetry", ] [[package]] +name = "opentelemetry_api" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c24f96e21e7acc813c7a8394ee94978929db2bcc46cf6b5014fc612bf7760c22" +dependencies = [ + "fnv", + "futures-channel", + "futures-util", + "indexmap", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ca41c4933371b61c2a2f214bf16931499af4ec90543604ec828f7a625c09113" +dependencies = [ + "async-trait", + "crossbeam-channel", + "dashmap", + "fnv", + "futures-channel", + "futures-executor", + "futures-util", + "once_cell", + "opentelemetry_api", + "percent-encoding", + "rand 0.8.5", + "thiserror", + "tokio", + "tokio-stream", +] + +[[package]] name = "ordered-float" version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1658,9 +1636,15 @@ dependencies = [ [[package]] name = "os_str_bytes" -version = "6.1.0" +version = "6.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21326818e99cfe6ce1e524c2a805c189a99b5ae555a35d19f9a284b427d86afa" +checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" + +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "page_size" @@ -1674,34 +1658,32 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.11.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ - "instant", "lock_api", "parking_lot_core", ] [[package]] name = "parking_lot_core" -version = "0.8.5" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" +checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0" dependencies = [ - "cfg-if 1.0.0", - "instant", + "cfg-if", "libc", "redox_syscall", "smallvec", - "winapi", + "windows-sys 0.42.0", ] [[package]] name = "paste" -version = "1.0.7" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c520e05135d6e763148b6426a837e239041653ba7becd2e538c076c738025fc" +checksum = "b1de2e551fb905ac83f73f7aedf2f0cb4a0da7e35efa24a202a936269f1f18e1" [[package]] name = "pear" @@ -1734,26 +1716,24 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" [[package]] name = "pem" -version = "0.8.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd56cbd21fea48d0c440b41cd69c589faacade08c992d9a54e471b79d0fd13eb" +checksum = "03c64931a1a212348ec4f3b4362585eca7159d0d09cbdf4a7f74f02173596fd4" dependencies = [ - "base64 0.13.0", - "once_cell", - "regex", + "base64 0.13.1", ] [[package]] name = "percent-encoding" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "persy" -version = "1.2.6" +version = "1.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5af61053f1daed3ff0265fad7f924e43ce07642a336c79304f8e5aec205460fb" +checksum = "5511189f4dbd737283b0dd2ff6715f2e35fd0d3e1ddf953ed6a772e439e1f73f" dependencies = [ "crc", "data-encoding", @@ -1767,18 +1747,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.0.10" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58ad3879ad3baf4e44784bc6a718a8698867bb991f8ce24d1bcbe2cfb4c3a75e" +checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.0.10" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744b6f092ba29c3650faf274db506afd39944f48420f6c86b17cfe0ee1cb36bb" +checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" dependencies = [ "proc-macro2", "quote", @@ -1799,45 +1779,45 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkcs8" -version = "0.7.6" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee3ef9b64d26bad0536099c816c6734379e45bbd5f14798def6809e5cc350447" +checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba" dependencies = [ "der", "spki", - "zeroize", ] [[package]] name = "pkg-config" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" +checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" [[package]] name = "png" -version = "0.16.8" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c3287920cb847dee3de33d301c463fba14dda99db24214ddf93f83d3021f4c6" +checksum = "5d708eaf860a19b19ce538740d2b4bdeeb8337fa53f7738455e706623ad5c638" dependencies = [ "bitflags", "crc32fast", - "deflate", - "miniz_oxide 0.3.7", + "flate2", + "miniz_oxide", ] [[package]] name = "ppv-lite86" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro-crate" -version = "1.1.3" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e17d47ce914bf4de440332250b0edd23ce48c005f59fab39d3335866b114f11a" +checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9" dependencies = [ + "once_cell", "thiserror", "toml", ] @@ -1868,9 +1848,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.39" +version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c54b25569025b7fc9651de43004ae593a75ad88543b17178aa5e1b9c4f15f56f" +checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725" dependencies = [ "unicode-ident", ] @@ -1896,9 +1876,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.19" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f53dc8cf16a769a6f677e09e7ff2cd4be1ea0f48754aac39520536962011de0d" +checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" dependencies = [ "proc-macro2", ] @@ -1924,7 +1904,7 @@ checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", "rand_chacha 0.3.1", - "rand_core 0.6.3", + "rand_core 0.6.4", ] [[package]] @@ -1944,7 +1924,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.3", + "rand_core 0.6.4", ] [[package]] @@ -1958,11 +1938,11 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.7", + "getrandom 0.2.8", ] [[package]] @@ -1976,9 +1956,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.13" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f25bc4c7e55e0b0b7a1d43fb893f4fa1361d0abe38b9ce4f323c2adfe6ef42" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" dependencies = [ "bitflags", ] @@ -1989,16 +1969,16 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ - "getrandom 0.2.7", + "getrandom 0.2.8", "redox_syscall", "thiserror", ] [[package]] name = "regex" -version = "1.5.6" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d83f127d94bdbcda4c8cc2e50f6f84f4b611f69c902699ca385a39c3a75f9ff1" +checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a" dependencies = [ "aho-corasick", "memchr", @@ -2016,16 +1996,16 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.26" +version = "0.6.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49b3de9ec5dc0a3417da371aab17d729997c15010e7fd24ff707773a33bddb64" +checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" [[package]] name = "reqwest" version = "0.11.9" source = "git+https://github.com/timokoesters/reqwest?rev=57b7cf4feb921573dfafad7d34b9ac6e44ead0bd#57b7cf4feb921573dfafad7d34b9ac6e44ead0bd" dependencies = [ - "base64 0.13.0", + "base64 0.13.1", "bytes", "encoding_rs", "futures-core", @@ -2095,11 +2075,12 @@ dependencies = [ [[package]] name = "ruma" -version = "0.5.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.7.4" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ "assign", "js_int", + "js_option", "ruma-appservice-api", "ruma-client-api", "ruma-common", @@ -2112,9 +2093,10 @@ dependencies = [ [[package]] name = "ruma-appservice-api" -version = "0.5.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.7.0" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ + "js_int", "ruma-common", "serde", "serde_json", @@ -2122,13 +2104,14 @@ dependencies = [ [[package]] name = "ruma-client-api" -version = "0.13.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.15.3" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ "assign", "bytes", "http", "js_int", + "js_option", "maplit", "percent-encoding", "ruma-common", @@ -2138,19 +2121,21 @@ dependencies = [ [[package]] name = "ruma-common" -version = "0.8.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.10.5" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ - "base64 0.13.0", + "base64 0.20.0", "bytes", "form_urlencoded", "http", "indexmap", - "indoc", "itoa", "js_int", + "js_option", + "konst", "percent-encoding", "rand 0.8.5", + "regex", "ruma-identifiers-validation", "ruma-macros", "serde", @@ -2164,8 +2149,8 @@ dependencies = [ [[package]] name = "ruma-federation-api" -version = "0.4.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.6.0" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ "js_int", "ruma-common", @@ -2175,17 +2160,17 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" -version = "0.7.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.9.0" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ + "js_int", "thiserror", - "url", ] [[package]] name = "ruma-identity-service-api" -version = "0.4.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.6.0" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ "js_int", "ruma-common", @@ -2194,20 +2179,23 @@ dependencies = [ [[package]] name = "ruma-macros" -version = "0.1.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.10.5" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ + "once_cell", "proc-macro-crate", "proc-macro2", "quote", "ruma-identifiers-validation", + "serde", "syn", + "toml", ] [[package]] name = "ruma-push-gateway-api" -version = "0.4.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.6.0" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ "js_int", "ruma-common", @@ -2217,24 +2205,24 @@ dependencies = [ [[package]] name = "ruma-signatures" -version = "0.10.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.12.0" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ - "base64 0.13.0", + "base64 0.20.0", "ed25519-dalek", "pkcs8", "rand 0.7.3", "ruma-common", "serde_json", "sha2", + "subslice", "thiserror", - "tracing", ] [[package]] name = "ruma-state-res" -version = "0.6.0" -source = "git+https://github.com/ruma/ruma?rev=d614ad1422d6c4b3437ebc318ca8514ae338fd6d#d614ad1422d6c4b3437ebc318ca8514ae338fd6d" +version = "0.8.0" +source = "git+https://github.com/ruma/ruma?rev=67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26#67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26" dependencies = [ "itertools", "js_int", @@ -2247,29 +2235,28 @@ dependencies = [ [[package]] name = "rusqlite" -version = "0.25.4" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c4b1eaf239b47034fb450ee9cdedd7d0226571689d8823030c4b6c2cb407152" +checksum = "01e213bc3ecb39ac32e81e51ebe31fd888a940515173e3a18a35f8c6e896422a" dependencies = [ "bitflags", "fallible-iterator", "fallible-streaming-iterator", "hashlink", "libsqlite3-sys", - "memchr", "smallvec", ] [[package]] name = "rust-argon2" -version = "0.8.3" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b18820d944b33caa75a71378964ac46f58517c92b6ae5f762636247c09e78fb" +checksum = "b50162d19404029c1ceca6f6980fe40d45c8b369f6f44446fa14bb39573b5bb9" dependencies = [ - "base64 0.13.0", + "base64 0.13.1", "blake2b_simd", "constant_time_eq", - "crossbeam-utils 0.8.9", + "crossbeam-utils", ] [[package]] @@ -2280,9 +2267,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustls" -version = "0.20.6" +version = "0.20.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c" dependencies = [ "log", "ring", @@ -2297,7 +2284,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" dependencies = [ "openssl-probe", - "rustls-pemfile 1.0.0", + "rustls-pemfile 1.0.1", "schannel", "security-framework", ] @@ -2308,23 +2295,23 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5eebeaeb360c87bfb72e84abdb3447159c0eaececf1bef2aecd65a8be949d1c9" dependencies = [ - "base64 0.13.0", + "base64 0.13.1", ] [[package]] name = "rustls-pemfile" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7522c9de787ff061458fe9a829dc790a3f5b22dc571694fc5883f448b94d9a9" +checksum = "0864aeff53f8c05aa08d86e5ef839d3dfcf07aeba2db32f12db0ef716e87bd55" dependencies = [ - "base64 0.13.0", + "base64 0.13.1", ] [[package]] name = "ryu" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3f6f92acf49d1b98f7a81226834412ada05458b7364277387724a237f062695" +checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" [[package]] name = "schannel" @@ -2333,7 +2320,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2" dependencies = [ "lazy_static", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -2353,10 +2340,16 @@ dependencies = [ ] [[package]] +name = "sd-notify" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "621e3680f3e07db4c9c2c3fb07c6223ab2fab2e54bd3c04c3ae037990f428c32" + +[[package]] name = "security-framework" -version = "2.6.1" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dc14f172faf8a0194a3aded622712b0de276821addc574fa54fc0a1167e10dc" +checksum = "2bc1bb97804af6631813c55739f771071e0f2ed33ee20b68c86ec505d906356c" dependencies = [ "bitflags", "core-foundation", @@ -2377,18 +2370,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.137" +version = "1.0.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61ea8d54c77f8315140a05f4c7237403bf38b72704d031543aa1d16abbf517d1" +checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.137" +version = "1.0.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f26faba0c3959972377d3b2d306ee9f71faee9714294e41bb777f83f88578be" +checksum = "4f1d362ca8fc9c3e3a7484440752472d68a6caa98f1ab81d99b5dfe517cec852" dependencies = [ "proc-macro2", "quote", @@ -2397,9 +2390,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.81" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b7ce2b32a1aed03c558dc61a5cd328f15aff2dbc17daad8fb8af04d2100e15c" +checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db" dependencies = [ "itoa", "ryu", @@ -2420,38 +2413,37 @@ dependencies = [ [[package]] name = "serde_yaml" -version = "0.8.24" +version = "0.9.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707d15895415db6628332b737c838b88c598522e4dc70647e59b72312924aebc" +checksum = "6d232d893b10de3eb7258ff01974d6ee20663d8e833263c99409d4b13a0209da" dependencies = [ "indexmap", + "itoa", "ryu", "serde", - "yaml-rust", + "unsafe-libyaml", ] [[package]] name = "sha-1" -version = "0.9.8" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" +checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" dependencies = [ - "block-buffer 0.9.0", - "cfg-if 1.0.0", + "cfg-if", "cpufeatures", - "digest 0.9.0", - "opaque-debug", + "digest 0.10.6", ] [[package]] -name = "sha-1" -version = "0.10.0" +name = "sha1" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" +checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "cpufeatures", - "digest 0.10.3", + "digest 0.10.6", ] [[package]] @@ -2461,7 +2453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" dependencies = [ "block-buffer 0.9.0", - "cfg-if 1.0.0", + "cfg-if", "cpufeatures", "digest 0.9.0", "opaque-debug", @@ -2493,66 +2485,42 @@ dependencies = [ [[package]] name = "signature" -version = "1.5.0" +version = "1.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f054c6c1a6e95179d6f23ed974060dcefb2d9388bb7256900badad682c499de4" +checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c" [[package]] name = "simple_asn1" -version = "0.4.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "692ca13de57ce0613a363c8c2f1de925adebc81b04c923ac60c5488bb44abe4b" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" dependencies = [ - "chrono", "num-bigint", "num-traits", + "thiserror", + "time", ] [[package]] name = "slab" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" - -[[package]] -name = "sled" -version = "0.34.7" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935" +checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" dependencies = [ - "crc32fast", - "crossbeam-epoch", - "crossbeam-utils 0.8.9", - "fs2", - "fxhash", - "libc", - "log", - "parking_lot", - "zstd", + "autocfg", ] [[package]] name = "smallvec" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" +checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "socket2" -version = "0.3.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e" -dependencies = [ - "cfg-if 1.0.0", - "libc", - "winapi", -] - -[[package]] -name = "socket2" -version = "0.4.4" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66d72b759436ae32898a2af0a14218dbf55efde3feeb170eb623637db85ee1e0" +checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" dependencies = [ "libc", "winapi", @@ -2566,14 +2534,24 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "spki" -version = "0.4.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c01a0c15da1b0b0e1494112e7af814a678fec9bd157881b49beac661e9b6f32" +checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b" dependencies = [ + "base64ct", "der", ] [[package]] +name = "subslice" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0a8e4809a3bb02de01f1f7faf1ba01a83af9e8eabcd4d31dd6e413d14d56aae" +dependencies = [ + "memchr", +] + +[[package]] name = "subtle" version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2581,9 +2559,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.98" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c50aef8a904de4c23c788f104b7dddc7d6f79c647c7c8ce4cc8f73eb0ca773dd" +checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d" dependencies = [ "proc-macro2", "quote", @@ -2598,11 +2576,11 @@ checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" [[package]] name = "synchronoise" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d717ed0efc9d39ab3b642a096bc369a3e02a38a51c41845d7fe31bdad1d6eaeb" +checksum = "3dbc01390fc626ce8d1cffe3376ded2b72a11bb70e1c75f404a210e4daa4def2" dependencies = [ - "crossbeam-queue 0.1.2", + "crossbeam-queue", ] [[package]] @@ -2618,25 +2596,19 @@ dependencies = [ ] [[package]] -name = "textwrap" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1141d4d61095b28419e22cb0bbf02755f5e54e0526f97f1e3d1d160e60885fb" - -[[package]] name = "thiserror" -version = "1.0.31" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd829fe32373d27f76265620b5309d0340cb8550f523c1dda251d6298069069a" +checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.31" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0396bc89e626244658bef819e22d0cc459e795a5ebe878e6ec336d1674a8d79a" +checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" dependencies = [ "proc-macro2", "quote", @@ -2663,9 +2635,9 @@ dependencies = [ [[package]] name = "thrift" -version = "0.13.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c6d965454947cc7266d22716ebfd07b18d84ebaf35eec558586bbb2a8cb6b5b" +checksum = "09678c4cdbb4eed72e18b7c2af1329c69825ed16fcbac62d083fc3e2b0590ff0" dependencies = [ "byteorder", "integer-encoding", @@ -2676,9 +2648,9 @@ dependencies = [ [[package]] name = "tikv-jemalloc-ctl" -version = "0.4.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb833c46ecbf8b6daeccb347cefcabf9c1beb5c9b0f853e1cec45632d9963e69" +checksum = "e37706572f4b151dff7a0146e040804e9c26fe3a3118591112f05cf12a4216c1" dependencies = [ "libc", "paste", @@ -2687,9 +2659,9 @@ dependencies = [ [[package]] name = "tikv-jemalloc-sys" -version = "0.4.3+5.2.1-patched.2" +version = "0.5.2+5.3.0-patched" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1792ccb507d955b46af42c123ea8863668fae24d03721e40cad6a41773dbb49" +checksum = "ec45c14da997d0925c7835883e4d5c181f196fa142f8c19d7643d1e9af2592c3" dependencies = [ "cc", "fs_extra", @@ -2698,9 +2670,9 @@ dependencies = [ [[package]] name = "tikv-jemallocator" -version = "0.4.3" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5b7bcecfafe4998587d636f9ae9d55eb9d0499877b88757767c346875067098" +checksum = "20612db8a13a6c06d57ec83953694185a367e16945f66565e8028d2c0bd76979" dependencies = [ "libc", "tikv-jemalloc-sys", @@ -2708,13 +2680,29 @@ dependencies = [ [[package]] name = "time" -version = "0.1.44" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db9e6914ab8b1ae1c260a4ae7a49b6c5611b40328a735b21862567685e73255" +checksum = "a561bf4617eebd33bca6434b988f39ed798e527f51a1e797d0ee4f61c0a38376" dependencies = [ - "libc", - "wasi 0.10.0+wasi-snapshot-preview1", - "winapi", + "itoa", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" + +[[package]] +name = "time-macros" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d967f99f534ca7e495c575c62638eebc2898a8c84c119b89e250477bc4ba16b2" +dependencies = [ + "time-core", ] [[package]] @@ -2734,19 +2722,19 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.19.2" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" +checksum = "d76ce4a75fb488c605c54bf610f221cea8b0dafb53333c1a67e8ee199dcd2ae3" dependencies = [ + "autocfg", "bytes", "libc", "memchr", "mio", "num_cpus", - "once_cell", "pin-project-lite", "signal-hook-registry", - "socket2 0.4.4", + "socket2", "tokio-macros", "winapi", ] @@ -2787,9 +2775,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.9" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df54d54117d6fdc4e4fea40fe1e4e566b3505700e148a6827e59b34b0d2600d9" +checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce" dependencies = [ "futures-core", "pin-project-lite", @@ -2798,9 +2786,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc463cd8deddc3770d20f9852143d50bf6094e640b485cb2e189a2099085ff45" +checksum = "0bb2e075f03b3d66d8d8785356224ba688d2906a371015e225beeb65ca92c740" dependencies = [ "bytes", "futures-core", @@ -2860,9 +2848,9 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" [[package]] name = "tower-service" @@ -2872,11 +2860,11 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" -version = "0.1.35" +version = "0.1.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a400e31aa60b9d44a52a8ee0343b5b18566b03a8321e0d321f695cf56e940160" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "log", "pin-project-lite", "tracing-attributes", @@ -2885,9 +2873,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.21" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc6b8ad3567499f98a1db7a752b07a7c8c7c7c34c332ec00effb2b0027974b7c" +checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", "quote", @@ -2896,9 +2884,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.27" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7709595b8878a4965ce5e87ebf880a7d39c9afc6837721b21a5a816a8117d921" +checksum = "24eb03ba0eab1fd845050058ce5e616558e8f8d8fca633e6b163fe25c797213a" dependencies = [ "once_cell", "valuable", @@ -2906,9 +2894,9 @@ dependencies = [ [[package]] name = "tracing-flame" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd520fe41c667b437952383f3a1ec14f1fa45d653f719a77eedd6e6a02d8fa54" +checksum = "0bae117ee14789185e129aaee5d93750abe67fdc5a9a62650452bfe4e122a3a9" dependencies = [ "lazy_static", "tracing", @@ -2927,79 +2915,79 @@ dependencies = [ ] [[package]] -name = "tracing-serde" -version = "0.1.3" +name = "tracing-opentelemetry" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +checksum = "21ebb87a95ea13271332df069020513ab70bdb5637ca42d6e492dc3bbbad48de" dependencies = [ - "serde", + "once_cell", + "opentelemetry", + "tracing", "tracing-core", + "tracing-log", + "tracing-subscriber", ] [[package]] name = "tracing-subscriber" -version = "0.2.25" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" dependencies = [ - "ansi_term", - "chrono", - "lazy_static", "matchers", + "nu-ansi-term", + "once_cell", "regex", - "serde", - "serde_json", "sharded-slab", "smallvec", "thread_local", "tracing", "tracing-core", "tracing-log", - "tracing-serde", ] [[package]] name = "trust-dns-proto" -version = "0.20.4" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca94d4e9feb6a181c690c4040d7a24ef34018d8313ac5044a61d21222ae24e31" +checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26" dependencies = [ "async-trait", - "cfg-if 1.0.0", + "cfg-if", "data-encoding", "enum-as-inner", "futures-channel", "futures-io", "futures-util", - "idna", + "idna 0.2.3", "ipnet", "lazy_static", - "log", "rand 0.8.5", "smallvec", "thiserror", "tinyvec", "tokio", + "tracing", "url", ] [[package]] name = "trust-dns-resolver" -version = "0.20.4" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecae383baad9995efaa34ce8e57d12c3f305e545887472a492b838f4b5cfb77a" +checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "futures-util", "ipconfig", "lazy_static", - "log", "lru-cache", "parking_lot", "resolv-conf", "smallvec", "thiserror", "tokio", + "tracing", "trust-dns-proto", ] @@ -3032,24 +3020,30 @@ checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992" [[package]] name = "unicode-ident" -version = "1.0.1" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bd2fe26506023ed7b5e1e315add59d6f584c621d037f9368fea9cfb988f368c" +checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3" [[package]] name = "unicode-normalization" -version = "0.1.19" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9" +checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" dependencies = [ "tinyvec", ] [[package]] name = "unicode-xid" -version = "0.2.3" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + +[[package]] +name = "unsafe-libyaml" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "957e51f3646910546462e67d5f7599b9e4fb8acdd304b087a6494730f9eebf04" +checksum = "c1e5fa573d8ac5f1a856f8d7be41d390ee973daf97c806b2c1a465e4e1406e68" [[package]] name = "unsigned-varint" @@ -3065,23 +3059,22 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "url" -version = "2.2.2" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507c383b2d33b5fc35d1861e77e6b383d158b2da5e14fe51b83dfedf6fd578c" +checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" dependencies = [ "form_urlencoded", - "idna", - "matches", + "idna 0.3.0", "percent-encoding", ] [[package]] name = "uuid" -version = "0.8.2" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7" +checksum = "422ee0de9031b5b948b97a8fc04e3aa35230001a722ddd27943e0be31564ce4c" dependencies = [ - "getrandom 0.2.7", + "getrandom 0.2.8", ] [[package]] @@ -3120,35 +3113,29 @@ checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" [[package]] name = "wasi" -version = "0.10.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" - -[[package]] -name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.81" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c53b543413a17a202f4be280a7e5c62a1c69345f5de525ee64f8cfdbc954994" +checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.81" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5491a68ab4500fa6b4d726bd67408630c3dbe9c4fe7bda16d5c82a1fd8c7340a" +checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" dependencies = [ "bumpalo", - "lazy_static", "log", + "once_cell", "proc-macro2", "quote", "syn", @@ -3157,11 +3144,11 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.31" +version = "0.4.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de9a9cec1733468a8c657e57fa2413d2ae2c0129b95e87c5b72b8ace4d13f31f" +checksum = "23639446165ca5a5de86ae1d8896b737ae80319560fbaa4c2887b7da6e7ebd7d" dependencies = [ - "cfg-if 1.0.0", + "cfg-if", "js-sys", "wasm-bindgen", "web-sys", @@ -3169,9 +3156,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.81" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c441e177922bc58f1e12c022624b6216378e5febc2f0533e41ba443d505b80aa" +checksum = "052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3179,9 +3166,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.81" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d94ac45fcf608c1f45ef53e748d35660f168490c10b23704c7779ab8f5c3048" +checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" dependencies = [ "proc-macro2", "quote", @@ -3192,15 +3179,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.81" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a89911bd99e5f3659ec4acf9c4d93b0a90fe4a2a11f15328472058edc5261be" +checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" [[package]] name = "web-sys" -version = "0.3.58" +version = "0.3.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fed94beee57daf8dd7d51f2b15dc2bcde92d7a72304cdf662a4371008b71b90" +checksum = "bcda906d8be16e728fd5adc5b729afad4e444e106ab28cd1c7256e54fa61510f" dependencies = [ "js-sys", "wasm-bindgen", @@ -3218,21 +3205,21 @@ dependencies = [ [[package]] name = "weezl" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c97e489d8f836838d497091de568cf16b117486d529ec5579233521065bd5e4" +checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" [[package]] name = "widestring" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c168940144dd21fd8046987c16a46a33d5fc84eec29ef9dcddc2ac9e31526b7c" +checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983" [[package]] name = "wildmatch" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c48bd20df7e4ced539c12f570f937c6b4884928a87fee70a479d72f031d4e0" +checksum = "ee583bdc5ff1cf9db20e9db5bb3ff4c3089a8f6b8b31aff265c9aba85812db86" [[package]] name = "winapi" @@ -3262,51 +3249,99 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" dependencies = [ - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_msvc", + "windows_aarch64_msvc 0.36.1", + "windows_i686_gnu 0.36.1", + "windows_i686_msvc 0.36.1", + "windows_x86_64_gnu 0.36.1", + "windows_x86_64_msvc 0.36.1", +] + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc 0.42.0", + "windows_i686_gnu 0.42.0", + "windows_i686_msvc 0.42.0", + "windows_x86_64_gnu 0.42.0", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc 0.42.0", ] [[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" + +[[package]] name = "windows_aarch64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" [[package]] +name = "windows_aarch64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" + +[[package]] name = "windows_i686_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" [[package]] +name = "windows_i686_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" + +[[package]] name = "windows_i686_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" [[package]] +name = "windows_i686_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" + +[[package]] name = "windows_x86_64_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" [[package]] +name = "windows_x86_64_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" + +[[package]] name = "windows_x86_64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" [[package]] -name = "winreg" -version = "0.6.2" +name = "windows_x86_64_msvc" +version = "0.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2986deb581c4fe11b621998a5e53361efe6b48a151178d0cd9eeffa4dc6acc9" -dependencies = [ - "winapi", -] +checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" [[package]] name = "winreg" @@ -3318,12 +3353,12 @@ dependencies = [ ] [[package]] -name = "yaml-rust" -version = "0.4.5" +name = "winreg" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85" +checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" dependencies = [ - "linked-hash-map", + "winapi", ] [[package]] @@ -3334,9 +3369,9 @@ checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" [[package]] name = "zeroize" -version = "1.3.0" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4756f7db3f7b5574938c3eb1c117038b8e07f95ee6718c0efad4ac21508f1efd" +checksum = "c394b5bd0c6f669e7275d9c20aa90ae064cb22e75a1cad54e1b34088034b149f" dependencies = [ "zeroize_derive", ] @@ -3361,32 +3396,3 @@ checksum = "70b40401a28d86ce16a330b863b86fd7dbee4d7c940587ab09ab8c019f9e3fdf" dependencies = [ "num-traits", ] - -[[package]] -name = "zstd" -version = "0.9.2+zstd.1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2390ea1bf6c038c39674f22d95f0564725fc06034a47129179810b2fc58caa54" -dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "4.1.3+zstd.1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e99d81b99fb3c2c2c794e3fe56c305c63d5173a16a46b5850b07c935ffc7db79" -dependencies = [ - "libc", - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "1.6.2+zstd.1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2daf2f248d9ea44454bfcb2516534e8b8ad2fc91bf818a1885495fc42bc8ac9f" -dependencies = [ - "cc", - "libc", -] @@ -6,29 +6,29 @@ authors = ["timokoesters <timo@koesters.xyz>"] homepage = "https://conduit.rs" repository = "https://gitlab.com/famedly/conduit" readme = "README.md" -version = "0.4.0" -rust-version = "1.56" +version = "0.5.0" +rust-version = "1.64" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] # Web framework -axum = { version = "0.5.8", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true } +axum = { version = "0.5.17", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true } axum-server = { version = "0.4.0", features = ["tls-rustls"] } tower = { version = "0.4.8", features = ["util"] } tower-http = { version = "0.3.4", features = ["add-extension", "cors", "compression-full", "sensitive-headers", "trace", "util"] } # Used for matrix spec type definitions and helpers #ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } -ruma = { git = "https://github.com/ruma/ruma", rev = "d614ad1422d6c4b3437ebc318ca8514ae338fd6d", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-pre-spec", "unstable-exhaustive-types"] } +ruma = { git = "https://github.com/ruma/ruma", rev = "67d0f3cc04a8d1dc4a8a1ec947519967ce11ce26", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] } #ruma = { git = "https://github.com/timokoesters/ruma", rev = "50c1db7e0a3a21fc794b0cce3b64285a4c750c71", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } # Async runtime and utilities tokio = { version = "1.11.0", features = ["fs", "macros", "signal", "sync"] } # Used for storing data permanently -sled = { version = "0.34.7", features = ["compression", "no_metrics"], optional = true } +#sled = { version = "0.34.7", features = ["compression", "no_metrics"], optional = true } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } persy = { version = "1.0.0", optional = true, features = ["background_ops"] } @@ -40,38 +40,39 @@ directories = "4.0.0" # Used for ruma wrapper serde_json = { version = "1.0.68", features = ["raw_value"] } # Used for appservice registration files -serde_yaml = "0.8.21" +serde_yaml = "0.9.13" # Used for pdu definition serde = { version = "1.0.130", features = ["rc"] } # Used for secure identifiers rand = "0.8.4" # Used to hash passwords -rust-argon2 = "0.8.3" +rust-argon2 = "1.0.0" # Used to send requests reqwest = { default-features = false, features = ["rustls-tls-native-roots", "socks"], git = "https://github.com/timokoesters/reqwest", rev = "57b7cf4feb921573dfafad7d34b9ac6e44ead0bd" } # Used for conduit::Error type thiserror = "1.0.29" # Used to generate thumbnails for images -image = { version = "0.23.14", default-features = false, features = ["jpeg", "png", "gif"] } +image = { version = "0.24.4", default-features = false, features = ["jpeg", "png", "gif"] } # Used to encode server public key base64 = "0.13.0" # Used when hashing the state ring = "0.16.20" # Used when querying the SRV record of other servers -trust-dns-resolver = "0.20.3" +trust-dns-resolver = "0.22.0" # Used to find matching events for appservices regex = "1.5.4" # jwt jsonwebtokens -jsonwebtoken = "7.2.0" +jsonwebtoken = "8.1.1" # Performance measurements tracing = { version = "0.1.27", features = [] } -tracing-subscriber = "0.2.22" -tracing-flame = "0.1.0" -opentelemetry = { version = "0.16.0", features = ["rt-tokio"] } -opentelemetry-jaeger = { version = "0.15.0", features = ["rt-tokio"] } +tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +tracing-flame = "0.2.0" +opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } +opentelemetry-jaeger = { version = "0.17.0", features = ["rt-tokio"] } +tracing-opentelemetry = "0.18.0" lru-cache = "0.1.2" -rusqlite = { version = "0.25.3", optional = true, features = ["bundled"] } -parking_lot = { version = "0.11.2", optional = true } +rusqlite = { version = "0.28.0", optional = true, features = ["bundled"] } +parking_lot = { version = "0.12.1", optional = true } crossbeam = { version = "0.8.1", optional = true } num_cpus = "1.13.0" threadpool = "1.8.1" @@ -80,20 +81,24 @@ rocksdb = { version = "0.17.0", default-features = true, features = ["multi-thre thread_local = "1.1.3" # used for TURN server authentication -hmac = "0.11.0" -sha-1 = "0.9.8" +hmac = "0.12.1" +sha-1 = "0.10.0" # used for conduit's CLI and admin room command parsing -clap = { version = "3.2.5", default-features = false, features = ["std", "derive"] } +clap = { version = "4.0.11", default-features = false, features = ["std", "derive", "help", "usage", "error-context"] } futures-util = { version = "0.3.17", default-features = false } # Used for reading the configuration from conduit.toml & environment variables figment = { version = "0.10.6", features = ["env", "toml"] } -tikv-jemalloc-ctl = { version = "0.4.2", features = ["use_std"], optional = true } -tikv-jemallocator = { version = "0.4.1", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } +tikv-jemalloc-ctl = { version = "0.5.0", features = ["use_std"], optional = true } +tikv-jemallocator = { version = "0.5.0", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } +lazy_static = "1.4.0" +async-trait = "0.1.57" + +sd-notify = { version = "0.4.1", optional = true } [features] -default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "jemalloc"] -backend_sled = ["sled"] +default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "jemalloc", "systemd"] +#backend_sled = ["sled"] backend_persy = ["persy", "parking_lot"] backend_sqlite = ["sqlite"] backend_heed = ["heed", "crossbeam"] @@ -101,6 +106,7 @@ backend_rocksdb = ["rocksdb"] jemalloc = ["tikv-jemalloc-ctl", "tikv-jemallocator"] sqlite = ["rusqlite", "parking_lot", "tokio/signal"] conduit_bin = ["axum"] +systemd = ["sd-notify"] [[bin]] name = "conduit" @@ -2,7 +2,7 @@ > ## Getting help > -> If you run into any problems while setting up Conduit, write an email to `timo@koesters.xyz`, ask us +> If you run into any problems while setting up Conduit, write an email to `conduit@koesters.xyz`, ask us > in `#conduit:fachschaften.org` or [open an issue on GitLab](https://gitlab.com/famedly/conduit/-/issues/new). ## Installing Conduit @@ -12,21 +12,27 @@ only offer Linux binaries. You may simply download the binary that fits your machine. Run `uname -m` to see what you need. Now copy the right url: -| CPU Architecture | Download stable version | Download development version | -| ------------------------------------------- | ------------------------------ | ---------------------------- | -| x84_64 / amd64 (Most servers and computers) | [Download][x84_64-musl-master] | [Download][x84_64-musl-next] | -| armv6 | [Download][armv6-musl-master] | [Download][armv6-musl-next] | -| armv7 (e.g. Raspberry Pi by default) | [Download][armv7-musl-master] | [Download][armv7-musl-next] | -| armv8 / aarch64 | [Download][armv8-musl-master] | [Download][armv8-musl-next] | - -[x84_64-musl-master]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/conduit-x86_64-unknown-linux-musl?job=build:release:cargo:x86_64-unknown-linux-musl -[armv6-musl-master]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/conduit-arm-unknown-linux-musleabihf?job=build:release:cargo:arm-unknown-linux-musleabihf -[armv7-musl-master]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/conduit-armv7-unknown-linux-musleabihf?job=build:release:cargo:armv7-unknown-linux-musleabihf -[armv8-musl-master]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/conduit-aarch64-unknown-linux-musl?job=build:release:cargo:aarch64-unknown-linux-musl -[x84_64-musl-next]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/conduit-x86_64-unknown-linux-musl?job=build:release:cargo:x86_64-unknown-linux-musl -[armv6-musl-next]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/conduit-arm-unknown-linux-musleabihf?job=build:release:cargo:arm-unknown-linux-musleabihf -[armv7-musl-next]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/conduit-armv7-unknown-linux-musleabihf?job=build:release:cargo:armv7-unknown-linux-musleabihf -[armv8-musl-next]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/conduit-aarch64-unknown-linux-musl?job=build:release:cargo:aarch64-unknown-linux-musl +| CPU Architecture | Download stable version | Download development version | +| ------------------------------------------- | --------------------------------------------------------------- | ----------------------------------------------------------- | +| x84_64 / amd64 (Most servers and computers) | [Binary][x84_64-glibc-master] / [.deb][x84_64-glibc-master-deb] | [Binary][x84_64-glibc-next] / [.deb][x84_64-glibc-next-deb] | +| armv7 (e.g. Raspberry Pi by default) | [Binary][armv7-glibc-master] / [.deb][armv7-glibc-master-deb] | [Binary][armv7-glibc-next] / [.deb][armv7-glibc-next-deb] | +| armv8 / aarch64 | [Binary][armv8-glibc-master] / [.deb][armv8-glibc-master-deb] | [Binary][armv8-glibc-next] / [.deb][armv8-glibc-next-deb] | + +These builds were created on and linked against the glibc version shipped with Debian bullseye. +If you use a system with an older glibc version, you might need to compile Conduit yourself. + +[x84_64-glibc-master]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/build-output/linux_amd64/conduit?job=docker:master +[armv7-glibc-master]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/build-output/linux_arm_v7/conduit?job=docker:master +[armv8-glibc-master]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/build-output/linux_arm64/conduit?job=docker:master +[x84_64-glibc-next]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/build-output/linux_amd64/conduit?job=docker:next +[armv7-glibc-next]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/build-output/linux_arm_v7/conduit?job=docker:next +[armv8-glibc-next]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/build-output/linux_arm64/conduit?job=docker:next +[x84_64-glibc-master-deb]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/build-output/linux_amd64/conduit.deb?job=docker:master +[armv7-glibc-master-deb]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/build-output/linux_arm_v7/conduit.deb?job=docker:master +[armv8-glibc-master-deb]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/master/raw/build-output/linux_arm64/conduit.deb?job=docker:master +[x84_64-glibc-next-deb]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/build-output/linux_amd64/conduit.deb?job=docker:next +[armv7-glibc-next-deb]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/build-output/linux_arm_v7/conduit.deb?job=docker:next +[armv8-glibc-next-deb]: https://gitlab.com/famedly/conduit/-/jobs/artifacts/next/raw/build-output/linux_arm64/conduit.deb?job=docker:next ```bash $ sudo wget -O /usr/local/bin/matrix-conduit <url> @@ -43,8 +49,25 @@ $ sudo apt install libclang-dev build-essential $ cargo build --release ``` +If you want to cross compile Conduit to another architecture, read the guide below. -If you want to cross compile Conduit to another architecture, read the [Cross-Compile Guide](cross/README.md). +<details> +<summary>Cross compilation</summary> + +As easiest way to compile conduit for another platform [cross-rs](https://github.com/cross-rs/cross) is recommended, so install it first. + +In order to use RockDB as storage backend append `-latomic` to linker flags. + +For example, to build a binary for Raspberry Pi Zero W (ARMv6) you need `arm-unknown-linux-gnueabihf` as compilation +target. + +```bash +git clone https://gitlab.com/famedly/conduit.git +cd conduit +export RUSTFLAGS='-C link-arg=-lgcc -Clink-arg=-latomic -Clink-arg=-static-libgcc' +cross build --release --no-default-features --features conduit_bin,backend_rocksdb,jemalloc --target=arm-unknown-linux-gnueabihf +``` +</details> ## Adding a Conduit user @@ -136,7 +159,7 @@ allow_federation = true trusted_servers = ["matrix.org"] #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time -#log = "info,state_res=warn,rocket=off,_=off,sled=off" +#log = "warn,state_res=warn,rocket=off,_=off,sled=off" address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy #address = "0.0.0.0" # If Conduit is running in a container, make sure the reverse proxy (ie. Traefik) can reach it. @@ -189,18 +212,21 @@ $ sudo systemctl reload apache2 ``` ### Caddy + Create `/etc/caddy/conf.d/conduit_caddyfile` and enter this (substitute for your server name). + ```caddy your.server.name, your.server.name:8448 { reverse_proxy /_matrix/* 127.0.0.1:6167 } ``` + That's it! Just start or enable the service and you're set. + ```bash $ sudo systemctl enable caddy ``` - ### Nginx If you use Nginx and not Apache, add the following server section inside the http section of `/etc/nginx/nginx.conf` @@ -214,6 +240,9 @@ server { server_name your.server.name; # EDIT THIS merge_slashes off; + # Nginx defaults to only allow 1MB uploads + client_max_body_size 20M; + location /_matrix/ { proxy_pass http://127.0.0.1:6167$request_uri; proxy_set_header Host $http_host; @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -FROM docker.io/rust:1.58-bullseye AS builder +FROM docker.io/rust:1.64-bullseye AS builder WORKDIR /usr/src/conduit # Install required packages to build Conduit and it's dependencies @@ -27,6 +27,49 @@ COPY src src # Builds conduit and places the binary at /usr/src/conduit/target/release/conduit RUN touch src/main.rs && touch src/lib.rs && cargo build --release + +# ONLY USEFUL FOR CI: target stage to extract build artifacts +FROM scratch AS builder-result +COPY --from=builder /usr/src/conduit/target/release/conduit /conduit + + + +# --------------------------------------------------------------------------------------------------------------- +# Build cargo-deb, a tool to package up rust binaries into .deb packages for Debian/Ubuntu based systems: +# --------------------------------------------------------------------------------------------------------------- +FROM docker.io/rust:1.64-bullseye AS build-cargo-deb + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + dpkg \ + dpkg-dev \ + liblzma-dev + +RUN cargo install cargo-deb +# => binary is in /usr/local/cargo/bin/cargo-deb + + +# --------------------------------------------------------------------------------------------------------------- +# Package conduit build-result into a .deb package: +# --------------------------------------------------------------------------------------------------------------- +FROM builder AS packager +WORKDIR /usr/src/conduit + +COPY ./LICENSE ./LICENSE +COPY ./README.md ./README.md +COPY debian/README.Debian ./debian/ +COPY --from=build-cargo-deb /usr/local/cargo/bin/cargo-deb /usr/local/cargo/bin/cargo-deb + +# --no-build makes cargo-deb reuse already compiled project +RUN cargo deb --no-build +# => Package is in /usr/src/conduit/target/debian/<project_name>_<version>_<arch>.deb + + +# ONLY USEFUL FOR CI: target stage to extract build artifacts +FROM scratch AS packager-result +COPY --from=packager /usr/src/conduit/target/debian/*.deb /conduit.deb + + # --------------------------------------------------------------------------------------------------------------- # Stuff below this line actually ends up in the resulting docker image # --------------------------------------------------------------------------------------------------------------- @@ -36,30 +79,32 @@ FROM docker.io/debian:bullseye-slim AS runner # You still need to map the port when using the docker command or docker-compose. EXPOSE 6167 +ARG DEFAULT_DB_PATH=/var/lib/matrix-conduit + ENV CONDUIT_PORT=6167 \ CONDUIT_ADDRESS="0.0.0.0" \ - CONDUIT_DATABASE_PATH=/var/lib/matrix-conduit \ + CONDUIT_DATABASE_PATH=${DEFAULT_DB_PATH} \ CONDUIT_CONFIG='' # └─> Set no config file to do all configuration with env vars # Conduit needs: +# dpkg: to install conduit.deb # ca-certificates: for https # iproute2 & wget: for the healthcheck script RUN apt-get update && apt-get -y --no-install-recommends install \ + dpkg \ ca-certificates \ iproute2 \ wget \ && rm -rf /var/lib/apt/lists/* -# Created directory for the database and media files -RUN mkdir -p /srv/conduit/.local/share/conduit - # Test if Conduit is still alive, uses the same endpoint as Element COPY ./docker/healthcheck.sh /srv/conduit/healthcheck.sh HEALTHCHECK --start-period=5s --interval=5s CMD ./healthcheck.sh -# Copy over the actual Conduit binary from the builder stage -COPY --from=builder /usr/src/conduit/target/release/conduit /srv/conduit/conduit +# Install conduit.deb: +COPY --from=packager /usr/src/conduit/target/debian/*.deb /srv/conduit/ +RUN dpkg -i /srv/conduit/*.deb # Improve security: Don't run stuff as root, that does not need to run as root # Most distros also use 1000:1000 for the first real user, so this should resolve volume mounting problems. @@ -69,9 +114,11 @@ RUN set -x ; \ groupadd -r -g ${GROUP_ID} conduit ; \ useradd -l -r -M -d /srv/conduit -o -u ${USER_ID} -g conduit conduit && exit 0 ; exit 1 -# Change ownership of Conduit files to conduit user and group and make the healthcheck executable: +# Create database directory, change ownership of Conduit files to conduit user and group and make the healthcheck executable: RUN chown -cR conduit:conduit /srv/conduit && \ - chmod +x /srv/conduit/healthcheck.sh + chmod +x /srv/conduit/healthcheck.sh && \ + mkdir -p ${DEFAULT_DB_PATH} && \ + chown -cR conduit:conduit ${DEFAULT_DB_PATH} # Change user to conduit, no root permissions afterwards: USER conduit @@ -80,4 +127,4 @@ WORKDIR /srv/conduit # Run Conduit and print backtraces on panics ENV RUST_BACKTRACE=1 -ENTRYPOINT [ "/srv/conduit/conduit" ] +ENTRYPOINT [ "/usr/sbin/matrix-conduit" ] @@ -1,7 +1,12 @@ # Conduit - ### A Matrix homeserver written in Rust +#### What is Matrix? +[Matrix](https://matrix.org) is an open network for secure and decentralized +communication. Users from every Matrix homeserver can chat with users from all +other Matrix servers. You can even use bridges (also called Matrix appservices) +to communicate with users outside of Matrix, like a community on Discord. + #### What is the goal? An efficient Matrix homeserver that's easy to set up and just works. You can install @@ -13,9 +18,10 @@ friends or company. Yes! You can test our Conduit instance by opening a Matrix client (<https://app.element.io> or Element Android for example) and registering on the `conduit.rs` homeserver. -It is hosted on a ODROID HC 2 with 2GB RAM and a SAMSUNG Exynos 5422 CPU, which -was used in the Samsung Galaxy S5. It joined many big rooms including Matrix -HQ. +*Registration is currently disabled because of scammers. For an account please + message us (see contact section below).* + +Server hosting for conduit.rs is donated by the Matrix.org Foundation. #### What is the current status? @@ -25,8 +31,8 @@ from time to time. There are still a few important features missing: -- E2EE verification over federation -- Outgoing read receipts, typing, presence over federation +- E2EE emoji comparison over federation (E2EE chat works) +- Outgoing read receipts, typing, presence over federation (incoming works) Check out the [Conduit 1.0 Release Milestone](https://gitlab.com/famedly/conduit/-/milestones/3). @@ -34,6 +40,7 @@ Check out the [Conduit 1.0 Release Milestone](https://gitlab.com/famedly/conduit - Simple install (this was tested the most): [DEPLOY.md](DEPLOY.md) - Debian package: [debian/README.Debian](debian/README.Debian) +- Nix/NixOS: [nix/README.md](nix/README.md) - Docker: [docker/README.md](docker/README.md) If you want to connect an Appservice to Conduit, take a look at [APPSERVICES.md](APPSERVICES.md). @@ -49,13 +56,21 @@ If you want to connect an Appservice to Conduit, take a look at [APPSERVICES.md] #### Thanks to -Thanks to Famedly, Prototype Fund (DLR and German BMBF) and all other individuals for financially supporting this project. +Thanks to FUTO, Famedly, Prototype Fund (DLR and German BMBF) and all individuals for financially supporting this project. Thanks to the contributors to Conduit and all libraries we use, for example: - Ruma: A clean library for the Matrix Spec in Rust - axum: A modular web framework +#### Contact + +If you run into any question, feel free to +- Ask us in `#conduit:fachschaften.org` on Matrix +- Write an E-Mail to `conduit@koesters.xyz` +- Send an direct message to `timo@fachschaften.org` on Matrix +- [Open an issue on GitLab](https://gitlab.com/famedly/conduit/-/issues/new) + #### Donate Liberapay: <https://liberapay.com/timokoesters/>\ diff --git a/conduit-example.toml b/conduit-example.toml index 5eed070..0549030 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -46,7 +46,7 @@ enable_lightning_bolt = true trusted_servers = ["matrix.org"] #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time -#log = "info,state_res=warn,rocket=off,_=off,sled=off" +#log = "warn,state_res=warn,rocket=off,_=off,sled=off" address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy #address = "0.0.0.0" # If Conduit is running in a container, make sure the reverse proxy (ie. Traefik) can reach it. diff --git a/debian/postinst b/debian/postinst index aab2480..73e554b 100644 --- a/debian/postinst +++ b/debian/postinst @@ -77,7 +77,7 @@ allow_federation = true trusted_servers = ["matrix.org"] #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time -#log = "info,state_res=warn,rocket=off,_=off,sled=off" +#log = "warn,state_res=warn,rocket=off,_=off,sled=off" EOF fi ;; diff --git a/docker-compose.yml b/docker-compose.yml index 0a9d8f4..d9c32b5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -31,7 +31,7 @@ services: CONDUIT_ALLOW_FEDERATION: 'true' CONDUIT_TRUSTED_SERVERS: '["matrix.org"]' #CONDUIT_MAX_CONCURRENT_REQUESTS: 100 - #CONDUIT_LOG: info,rocket=off,_=off,sled=off + #CONDUIT_LOG: warn,rocket=off,_=off,sled=off CONDUIT_ADDRESS: 0.0.0.0 CONDUIT_CONFIG: '' # Ignore this # diff --git a/docker/README.md b/docker/README.md index c980adc..c702832 100644 --- a/docker/README.md +++ b/docker/README.md @@ -33,7 +33,7 @@ docker run -d -p 8448:6167 \ -e CONDUIT_MAX_REQUEST_SIZE="20_000_000" \ -e CONDUIT_TRUSTED_SERVERS="[\"matrix.org\"]" \ -e CONDUIT_MAX_CONCURRENT_REQUESTS="100" \ - -e CONDUIT_LOG="info,rocket=off,_=off,sled=off" \ + -e CONDUIT_LOG="warn,rocket=off,_=off,sled=off" \ --name conduit matrixconduit/matrix-conduit:latest ``` @@ -121,12 +121,12 @@ So...step by step: location /.well-known/matrix/server { return 200 '{"m.server": "<SUBDOMAIN>.<DOMAIN>:443"}'; - add_header Content-Type application/json; + types { } default_type "application/json; charset=utf-8"; } location /.well-known/matrix/client { return 200 '{"m.homeserver": {"base_url": "https://<SUBDOMAIN>.<DOMAIN>"}}'; - add_header Content-Type application/json; + types { } default_type "application/json; charset=utf-8"; add_header "Access-Control-Allow-Origin" *; } diff --git a/docker/docker-compose.for-traefik.yml b/docker/docker-compose.for-traefik.yml index ca560b8..474299f 100644 --- a/docker/docker-compose.for-traefik.yml +++ b/docker/docker-compose.for-traefik.yml @@ -31,7 +31,7 @@ services: CONDUIT_ALLOW_FEDERATION: 'true' CONDUIT_TRUSTED_SERVERS: '["matrix.org"]' #CONDUIT_MAX_CONCURRENT_REQUESTS: 100 - #CONDUIT_LOG: info,rocket=off,_=off,sled=off + #CONDUIT_LOG: warn,rocket=off,_=off,sled=off CONDUIT_ADDRESS: 0.0.0.0 CONDUIT_CONFIG: '' # Ignore this diff --git a/docker/docker-compose.with-traefik.yml b/docker/docker-compose.with-traefik.yml index 6d46827..79ebef4 100644 --- a/docker/docker-compose.with-traefik.yml +++ b/docker/docker-compose.with-traefik.yml @@ -33,7 +33,7 @@ services: # CONDUIT_PORT: 6167 # CONDUIT_CONFIG: '/srv/conduit/conduit.toml' # if you want to configure purely by env vars, set this to an empty string '' # Available levels are: error, warn, info, debug, trace - more info at: https://docs.rs/env_logger/*/env_logger/#enabling-logging - # CONDUIT_LOG: info # default is: "info,_=off,sled=off" + # CONDUIT_LOG: info # default is: "warn,_=off,sled=off" # CONDUIT_ALLOW_JAEGER: 'false' # CONDUIT_ALLOW_ENCRYPTION: 'false' # CONDUIT_ALLOW_FEDERATION: 'false' diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..9217ff2 --- /dev/null +++ b/flake.lock @@ -0,0 +1,102 @@ +{ + "nodes": { + "fenix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ], + "rust-analyzer-src": "rust-analyzer-src" + }, + "locked": { + "lastModified": 1665815894, + "narHash": "sha256-Vboo1L4NMGLKZKVLnOPi9OHlae7uoNyfgvyIUm+SVXE=", + "owner": "nix-community", + "repo": "fenix", + "rev": "2348450241a5f945f0ba07e44ecbfac2f541d7f4", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "fenix", + "type": "github" + } + }, + "flake-utils": { + "locked": { + "lastModified": 1659877975, + "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "naersk": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1662220400, + "narHash": "sha256-9o2OGQqu4xyLZP9K6kNe1pTHnyPz0Wr3raGYnr9AIgY=", + "owner": "nix-community", + "repo": "naersk", + "rev": "6944160c19cb591eb85bbf9b2f2768a935623ed3", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "naersk", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1665856037, + "narHash": "sha256-/RvIWnGKdTSoIq5Xc2HwPIL0TzRslzU6Rqk4Img6UNg=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "c95ebc5125ffffcd431df0ad8620f0926b8125b8", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "fenix": "fenix", + "flake-utils": "flake-utils", + "naersk": "naersk", + "nixpkgs": "nixpkgs" + } + }, + "rust-analyzer-src": { + "flake": false, + "locked": { + "lastModified": 1665765556, + "narHash": "sha256-w9L5j0TIB5ay4aRwzGCp8mgvGsu5dVJQvbEFutwr6xE=", + "owner": "rust-lang", + "repo": "rust-analyzer", + "rev": "018b8429cf3fa9d8aed916704e41dfedeb0f4f78", + "type": "github" + }, + "original": { + "owner": "rust-lang", + "ref": "nightly", + "repo": "rust-analyzer", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..924300c --- /dev/null +++ b/flake.nix @@ -0,0 +1,75 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs"; + flake-utils.url = "github:numtide/flake-utils"; + + fenix = { + url = "github:nix-community/fenix"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + naersk = { + url = "github:nix-community/naersk"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + }; + + outputs = + { self + , nixpkgs + , flake-utils + + , fenix + , naersk + }: flake-utils.lib.eachDefaultSystem (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + + # Nix-accessible `Cargo.toml` + cargoToml = builtins.fromTOML (builtins.readFile ./Cargo.toml); + + # The Rust toolchain to use + toolchain = fenix.packages.${system}.toolchainOf { + # Use the Rust version defined in `Cargo.toml` + channel = cargoToml.package.rust-version; + + # This will need to be updated when `package.rust-version` is changed in + # `Cargo.toml` + sha256 = "sha256-KXx+ID0y4mg2B3LHp7IyaiMrdexF6octADnAtFIOjrY="; + }; + + builder = (pkgs.callPackage naersk { + inherit (toolchain) rustc cargo; + }).buildPackage; + in + { + packages.default = builder { + src = ./.; + + nativeBuildInputs = (with pkgs.rustPlatform; [ + bindgenHook + ]); + }; + + devShells.default = pkgs.mkShell { + # Rust Analyzer needs to be able to find the path to default crate + # sources, and it can read this environment variable to do so + RUST_SRC_PATH = "${toolchain.rust-src}/lib/rustlib/src/rust/library"; + + # Development tools + nativeBuildInputs = (with pkgs.rustPlatform; [ + bindgenHook + ]) ++ (with toolchain; [ + cargo + clippy + rust-src + rustc + rustfmt + ]); + }; + + checks = { + packagesDefault = self.packages.${system}.default; + devShellsDefault = self.devShells.${system}.default; + }; + }); +} diff --git a/nix/README.md b/nix/README.md new file mode 100644 index 0000000..d92f910 --- /dev/null +++ b/nix/README.md @@ -0,0 +1,188 @@ +# Conduit for Nix/NixOS + +This guide assumes you have a recent version of Nix (^2.4) installed. + +Since Conduit ships as a Nix flake, you'll first need to [enable +flakes][enable_flakes]. + +You can now use the usual Nix commands to interact with Conduit's flake. For +example, `nix run gitlab:famedly/conduit` will run Conduit (though you'll need +to provide configuration and such manually as usual). + +If your NixOS configuration is defined as a flake, you can depend on this flake +to provide a more up-to-date version than provided by `nixpkgs`. In your flake, +add the following to your `inputs`: + +```nix +conduit = { + url = "gitlab:famedly/conduit"; + + # Assuming you have an input for nixpkgs called `nixpkgs`. If you experience + # build failures while using this, try commenting/deleting this line. This + # will probably also require you to always build from source. + inputs.nixpkgs.follows = "nixpkgs"; +}; +``` + +Next, make sure you're passing your flake inputs to the `specialArgs` argument +of `nixpkgs.lib.nixosSystem` [as explained here][specialargs]. This guide will +assume you've named the group `flake-inputs`. + +Now you can configure Conduit and a reverse proxy for it. Add the following to +a new Nix file and include it in your configuration: + +```nix +{ config +, pkgs +, flake-inputs +, ... +}: + +let + # You'll need to edit these values + + # The hostname that will appear in your user and room IDs + server_name = "example.com"; + + # The hostname that Conduit actually runs on + # + # This can be the same as `server_name` if you want. This is only necessary + # when Conduit is running on a different machine than the one hosting your + # root domain. This configuration also assumes this is all running on a single + # machine, some tweaks will need to be made if this is not the case. + matrix_hostname = "matrix.${server_name}"; + + # An admin email for TLS certificate notifications + admin_email = "admin@${server_name}"; + + # These ones you can leave alone + + # Build a dervation that stores the content of `${server_name}/.well-known/matrix/server` + well_known_server = pkgs.writeText "well-known-matrix-server" '' + { + "m.server": "${matrix_hostname}" + } + ''; + + # Build a dervation that stores the content of `${server_name}/.well-known/matrix/client` + well_known_client = pkgs.writeText "well-known-matrix-client" '' + { + "m.homeserver": { + "base_url": "https://${matrix_hostname}" + } + } + ''; +in + +{ + # Configure Conduit itself + services.matrix-conduit = { + enable = true; + + # This causes NixOS to use the flake defined in this repository instead of + # the build of Conduit built into nixpkgs. + package = flake-inputs.conduit.packages.${pkgs.system}.default; + + settings.global = { + inherit server_name; + }; + }; + + # Configure automated TLS acquisition/renewal + security.acme = { + acceptTerms = true; + defaults = { + email = admin_email; + }; + }; + + # ACME data must be readable by the NGINX user + users.users.nginx.extraGroups = [ + "acme" + ]; + + # Configure NGINX as a reverse proxy + services.nginx = { + enable = true; + recommendedProxySettings = true; + + virtualHosts = { + "${server_name}" = { + forceSSL = true; + enableACME = true; + + listen = [ + { + addr = "0.0.0.0"; + port = 443; + ssl = true; + } + { + addr = "0.0.0.0"; + port = 8448; + ssl = true; + } + ]; + + extraConfig = '' + merge_slashes off; + ''; + + "${matrix_hostname}" = { + forceSSL = true; + enableACME = true; + + locations."/_matrix/" = { + proxyPass = "http://backend_conduit$request_uri"; + proxyWebsockets = true; + extraConfig = '' + proxy_set_header Host $host; + proxy_buffering off; + ''; + }; + + locations."=/.well-known/matrix/server" = { + # Use the contents of the derivation built previously + alias = "${well_known_server}"; + + extraConfig = '' + # Set the header since by default NGINX thinks it's just bytes + default_type application/json; + ''; + }; + + locations."=/.well-known/matrix/client" = { + # Use the contents of the derivation built previously + alias = "${well_known_client}"; + + extraConfig = '' + # Set the header since by default NGINX thinks it's just bytes + default_type application/json; + + # https://matrix.org/docs/spec/client_server/r0.4.0#web-browser-clients + add_header Access-Control-Allow-Origin "*"; + ''; + }; + }; + }; + + upstreams = { + "backend_conduit" = { + servers = { + "localhost:${toString config.services.matrix-conduit.settings.global.port}" = { }; + }; + }; + }; + }; + + # Open firewall ports for HTTP, HTTPS, and Matrix federation + networking.firewall.allowedTCPPorts = [ 80 443 8448 ]; + networking.firewall.allowedUDPPorts = [ 80 443 8448 ]; +} +``` + +Now you can rebuild your system configuration and you should be good to go! + +[enable_flakes]: https://nixos.wiki/wiki/Flakes#Enable_flakes + +[specialargs]: https://nixos.wiki/wiki/Flakes#Using_nix_flakes_with_NixOS diff --git a/src/appservice_server.rs b/src/api/appservice_server.rs index ce122da..dc319e2 100644 --- a/src/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -1,12 +1,11 @@ -use crate::{utils, Error, Result}; +use crate::{services, utils, Error, Result}; use bytes::BytesMut; use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; use std::{fmt::Debug, mem, time::Duration}; use tracing::warn; -#[tracing::instrument(skip(globals, request))] +#[tracing::instrument(skip(request))] pub(crate) async fn send_request<T: OutgoingRequest>( - globals: &crate::database::globals::Globals, registration: serde_yaml::Value, request: T, ) -> Result<T::IncomingResponse> @@ -46,7 +45,23 @@ where *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); let url = reqwest_request.url().clone(); - let mut response = globals.default_client().execute(reqwest_request).await?; + let mut response = match services() + .globals + .default_client() + .execute(reqwest_request) + .await + { + Ok(r) => r, + Err(e) => { + warn!( + "Could not send request to appservice {:?} at {}: {}", + registration.get("id"), + destination, + e + ); + return Err(e.into()); + } + }; // reqwest::Response -> http::Response conversion let status = response.status(); diff --git a/src/client_server/account.rs b/src/api/client_server/account.rs index 1484bf6..7459254 100644 --- a/src/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -1,13 +1,11 @@ use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{ - database::{admin::make_user_admin, DatabaseGuard}, - utils, Error, Result, Ruma, -}; +use crate::{api::client_server, services, utils, Error, Result, Ruma}; use ruma::{ api::client::{ account::{ - change_password, deactivate, get_3pids, get_username_availability, register, whoami, - ThirdPartyIdRemovalStatus, + change_password, deactivate, get_3pids, get_username_availability, register, + request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, + whoami, ThirdPartyIdRemovalStatus, }, error::ErrorKind, uiaa::{AuthFlow, AuthType, UiaaInfo}, @@ -32,23 +30,24 @@ const RANDOM_USER_ID_LENGTH: usize = 10; /// /// Note: This will not reserve the username, so the username might become invalid when trying to register pub async fn get_register_available_route( - db: DatabaseGuard, - body: Ruma<get_username_availability::v3::IncomingRequest>, + body: Ruma<get_username_availability::v3::Request>, ) -> Result<get_username_availability::v3::Response> { // Validate user id - let user_id = - UserId::parse_with_server_name(body.username.to_lowercase(), db.globals.server_name()) - .ok() - .filter(|user_id| { - !user_id.is_historical() && user_id.server_name() == db.globals.server_name() - }) - .ok_or(Error::BadRequest( - ErrorKind::InvalidUsername, - "Username is invalid.", - ))?; + let user_id = UserId::parse_with_server_name( + body.username.to_lowercase(), + services().globals.server_name(), + ) + .ok() + .filter(|user_id| { + !user_id.is_historical() && user_id.server_name() == services().globals.server_name() + }) + .ok_or(Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ))?; // Check if username is creative enough - if db.users.exists(&user_id)? { + if services().users.exists(&user_id)? { return Err(Error::BadRequest( ErrorKind::UserInUse, "Desired user ID is already taken.", @@ -74,11 +73,8 @@ pub async fn get_register_available_route( /// - If type is not guest and no username is given: Always fails after UIAA check /// - Creates a new account and populates it with default account data /// - If `inhibit_login` is false: Creates a device and returns device id and access_token -pub async fn register_route( - db: DatabaseGuard, - body: Ruma<register::v3::IncomingRequest>, -) -> Result<register::v3::Response> { - if !db.globals.allow_registration() && !body.from_appservice { +pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<register::v3::Response> { + if !services().globals.allow_registration() && !body.from_appservice { return Err(Error::BadRequest( ErrorKind::Forbidden, "Registration has been disabled.", @@ -89,18 +85,20 @@ pub async fn register_route( let user_id = match (&body.username, is_guest) { (Some(username), false) => { - let proposed_user_id = - UserId::parse_with_server_name(username.to_lowercase(), db.globals.server_name()) - .ok() - .filter(|user_id| { - !user_id.is_historical() - && user_id.server_name() == db.globals.server_name() - }) - .ok_or(Error::BadRequest( - ErrorKind::InvalidUsername, - "Username is invalid.", - ))?; - if db.users.exists(&proposed_user_id)? { + let proposed_user_id = UserId::parse_with_server_name( + username.to_lowercase(), + services().globals.server_name(), + ) + .ok() + .filter(|user_id| { + !user_id.is_historical() + && user_id.server_name() == services().globals.server_name() + }) + .ok_or(Error::BadRequest( + ErrorKind::InvalidUsername, + "Username is invalid.", + ))?; + if services().users.exists(&proposed_user_id)? { return Err(Error::BadRequest( ErrorKind::UserInUse, "Desired user ID is already taken.", @@ -111,10 +109,10 @@ pub async fn register_route( _ => loop { let proposed_user_id = UserId::parse_with_server_name( utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), - db.globals.server_name(), + services().globals.server_name(), ) .unwrap(); - if !db.users.exists(&proposed_user_id)? { + if !services().users.exists(&proposed_user_id)? { break proposed_user_id; } }, @@ -133,14 +131,12 @@ pub async fn register_route( if !body.from_appservice { if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( - &UserId::parse_with_server_name("", db.globals.server_name()) + let (worked, uiaainfo) = services().uiaa.try_auth( + &UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid"), "".into(), auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -148,8 +144,8 @@ pub async fn register_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa.create( - &UserId::parse_with_server_name("", db.globals.server_name()) + services().uiaa.create( + &UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid"), "".into(), &uiaainfo, @@ -168,30 +164,31 @@ pub async fn register_route( }; // Create user - db.users.create(&user_id, password)?; + services().users.create(&user_id, password)?; // Default to pretty displayname let mut displayname = user_id.localpart().to_owned(); // If enabled append lightning bolt to display name (default true) - if db.globals.enable_lightning_bolt() { + if services().globals.enable_lightning_bolt() { displayname.push_str(" ⚡️"); } - db.users + services() + .users .set_displayname(&user_id, Some(displayname.clone()))?; // Initial account data - db.account_data.update( + services().account_data.update( None, &user_id, GlobalAccountDataEventType::PushRules.to_string().into(), - &ruma::events::push_rules::PushRulesEvent { + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { content: ruma::events::push_rules::PushRulesEventContent { global: push::Ruleset::server_default(&user_id), }, - }, - &db.globals, + }) + .expect("to json always works"), )?; // Inhibit login does not work for guests @@ -200,6 +197,8 @@ pub async fn register_route( access_token: None, user_id, device_id: None, + refresh_token: None, + expires_in: None, }); } @@ -215,7 +214,7 @@ pub async fn register_route( let token = utils::random_string(TOKEN_LENGTH); // Create device for this account - db.users.create_device( + services().users.create_device( &user_id, &device_id, &token, @@ -223,26 +222,29 @@ pub async fn register_route( )?; info!("New user {} registered on this server.", user_id); - db.admin + services() + .admin .send_message(RoomMessageEventContent::notice_plain(format!( - "New user {} registered on this server.", - user_id + "New user {user_id} registered on this server." ))); // If this is the first real user, grant them admin privileges // Note: the server user, @conduit:servername, is generated first - if db.users.count()? == 2 { - make_user_admin(&db, &user_id, displayname).await?; + if services().users.count()? == 2 { + services() + .admin + .make_user_admin(&user_id, displayname) + .await?; warn!("Granting {} admin privileges as the first user", user_id); } - db.flush()?; - Ok(register::v3::Response { access_token: Some(token), user_id, device_id: Some(device_id), + refresh_token: None, + expires_in: None, }) } @@ -261,8 +263,7 @@ pub async fn register_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn change_password_route( - db: DatabaseGuard, - body: Ruma<change_password::v3::IncomingRequest>, + body: Ruma<change_password::v3::Request>, ) -> Result<change_password::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -278,49 +279,45 @@ pub async fn change_password_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( - sender_user, - sender_device, - auth, - &uiaainfo, - &db.users, - &db.globals, - )?; + let (worked, uiaainfo) = + services() + .uiaa + .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services() + .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - db.users + services() + .users .set_password(sender_user, Some(&body.new_password))?; if body.logout_devices { // Logout all devices except the current one - for id in db + for id in services() .users .all_device_ids(sender_user) .filter_map(|id| id.ok()) .filter(|id| id != sender_device) { - db.users.remove_device(sender_user, &id)?; + services().users.remove_device(sender_user, &id)?; } } - db.flush()?; - info!("User {} changed their password.", sender_user); - db.admin + services() + .admin .send_message(RoomMessageEventContent::notice_plain(format!( - "User {} changed their password.", - sender_user + "User {sender_user} changed their password." ))); Ok(change_password::v3::Response {}) @@ -331,17 +328,14 @@ pub async fn change_password_route( /// Get user_id of the sender user. /// /// Note: Also works for Application Services -pub async fn whoami_route( - db: DatabaseGuard, - body: Ruma<whoami::v3::Request>, -) -> Result<whoami::v3::Response> { +pub async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let device_id = body.sender_device.as_ref().cloned(); Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: db.users.is_deactivated(&sender_user)?, + is_guest: services().users.is_deactivated(sender_user)? && !body.from_appservice, }) } @@ -356,8 +350,7 @@ pub async fn whoami_route( /// - Triggers device list updates /// - Removes ability to log in again pub async fn deactivate_route( - db: DatabaseGuard, - body: Ruma<deactivate::v3::IncomingRequest>, + body: Ruma<deactivate::v3::Request>, ) -> Result<deactivate::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -373,21 +366,18 @@ pub async fn deactivate_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( - sender_user, - sender_device, - auth, - &uiaainfo, - &db.users, - &db.globals, - )?; + let (worked, uiaainfo) = + services() + .uiaa + .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services() + .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -395,26 +385,24 @@ pub async fn deactivate_route( } // Make the user leave all rooms before deactivation - db.rooms.leave_all_rooms(&sender_user, &db).await?; + client_server::leave_all_rooms(sender_user).await?; // Remove devices and mark account as deactivated - db.users.deactivate_account(sender_user)?; + services().users.deactivate_account(sender_user)?; info!("User {} deactivated their account.", sender_user); - db.admin + services() + .admin .send_message(RoomMessageEventContent::notice_plain(format!( - "User {} deactivated their account.", - sender_user + "User {sender_user} deactivated their account." ))); - db.flush()?; - Ok(deactivate::v3::Response { id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, }) } -/// # `GET _matrix/client/r0/account/3pid` +/// # `GET _matrix/client/v3/account/3pid` /// /// Get a list of third party identifiers associated with this account. /// @@ -426,3 +414,31 @@ pub async fn third_party_route( Ok(get_3pids::v3::Response::new(Vec::new())) } + +/// # `POST /_matrix/client/v3/account/3pid/email/requestToken` +/// +/// "This API should be used to request validation tokens when adding an email address to an account" +/// +/// - 403 signals that The homeserver does not allow the third party identifier as a contact option. +pub async fn request_3pid_management_token_via_email_route( + _body: Ruma<request_3pid_management_token_via_email::v3::Request>, +) -> Result<request_3pid_management_token_via_email::v3::Response> { + Err(Error::BadRequest( + ErrorKind::ThreepidDenied, + "Third party identifier is not allowed", + )) +} + +/// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken` +/// +/// "This API should be used to request validation tokens when adding an phone number to an account" +/// +/// - 403 signals that The homeserver does not allow the third party identifier as a contact option. +pub async fn request_3pid_management_token_via_msisdn_route( + _body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>, +) -> Result<request_3pid_management_token_via_msisdn::v3::Response> { + Err(Error::BadRequest( + ErrorKind::ThreepidDenied, + "Third party identifier is not allowed", + )) +} diff --git a/src/client_server/alias.rs b/src/api/client_server/alias.rs index 90e9d2c..ab51b50 100644 --- a/src/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use regex::Regex; use ruma::{ api::{ @@ -9,31 +9,35 @@ use ruma::{ }, federation, }, - RoomAliasId, + OwnedRoomAliasId, }; /// # `PUT /_matrix/client/r0/directory/room/{roomAlias}` /// /// Creates a new room alias on this server. pub async fn create_alias_route( - db: DatabaseGuard, - body: Ruma<create_alias::v3::IncomingRequest>, + body: Ruma<create_alias::v3::Request>, ) -> Result<create_alias::v3::Response> { - if body.room_alias.server_name() != db.globals.server_name() { + if body.room_alias.server_name() != services().globals.server_name() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Alias is from another server.", )); } - if db.rooms.id_from_alias(&body.room_alias)?.is_some() { + if services() + .rooms + .alias + .resolve_local_alias(&body.room_alias)? + .is_some() + { return Err(Error::Conflict("Alias already exists.")); } - db.rooms - .set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?; - - db.flush()?; + services() + .rooms + .alias + .set_alias(&body.room_alias, &body.room_id)?; Ok(create_alias::v3::Response::new()) } @@ -45,22 +49,19 @@ pub async fn create_alias_route( /// - TODO: additional access control checks /// - TODO: Update canonical alias event pub async fn delete_alias_route( - db: DatabaseGuard, - body: Ruma<delete_alias::v3::IncomingRequest>, + body: Ruma<delete_alias::v3::Request>, ) -> Result<delete_alias::v3::Response> { - if body.room_alias.server_name() != db.globals.server_name() { + if body.room_alias.server_name() != services().globals.server_name() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Alias is from another server.", )); } - db.rooms.set_alias(&body.room_alias, None, &db.globals)?; + services().rooms.alias.remove_alias(&body.room_alias)?; // TODO: update alt_aliases? - db.flush()?; - Ok(delete_alias::v3::Response::new()) } @@ -70,23 +71,22 @@ pub async fn delete_alias_route( /// /// - TODO: Suggest more servers to join via pub async fn get_alias_route( - db: DatabaseGuard, - body: Ruma<get_alias::v3::IncomingRequest>, + body: Ruma<get_alias::v3::Request>, ) -> Result<get_alias::v3::Response> { - get_alias_helper(&db, &body.room_alias).await + get_alias_helper(body.body.room_alias).await } pub(crate) async fn get_alias_helper( - db: &Database, - room_alias: &RoomAliasId, + room_alias: OwnedRoomAliasId, ) -> Result<get_alias::v3::Response> { - if room_alias.server_name() != db.globals.server_name() { - let response = db + if room_alias.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, room_alias.server_name(), - federation::query::get_room_information::v1::Request { room_alias }, + federation::query::get_room_information::v1::Request { + room_alias: room_alias.to_owned(), + }, ) .await?; @@ -97,10 +97,10 @@ pub(crate) async fn get_alias_helper( } let mut room_id = None; - match db.rooms.id_from_alias(room_alias)? { + match services().rooms.alias.resolve_local_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - for (_id, registration) in db.appservice.all()? { + for (_id, registration) in services().appservice.all()? { let aliases = registration .get("namespaces") .and_then(|ns| ns.get("aliases")) @@ -115,19 +115,26 @@ pub(crate) async fn get_alias_helper( if aliases .iter() .any(|aliases| aliases.is_match(room_alias.as_str())) - && db + && services() .sending .send_appservice_request( - &db.globals, registration, - appservice::query::query_room_alias::v1::Request { room_alias }, + appservice::query::query_room_alias::v1::Request { + room_alias: room_alias.clone(), + }, ) .await .is_ok() { - room_id = Some(db.rooms.id_from_alias(room_alias)?.ok_or_else(|| { - Error::bad_config("Appservice lied to us. Room does not exist.") - })?); + room_id = Some( + services() + .rooms + .alias + .resolve_local_alias(&room_alias)? + .ok_or_else(|| { + Error::bad_config("Appservice lied to us. Room does not exist.") + })?, + ); break; } } @@ -146,6 +153,6 @@ pub(crate) async fn get_alias_helper( Ok(get_alias::v3::Response::new( room_id, - vec![db.globals.server_name().to_owned()], + vec![services().globals.server_name().to_owned()], )) } diff --git a/src/client_server/backup.rs b/src/api/client_server/backup.rs index 067f20c..115cba7 100644 --- a/src/client_server/backup.rs +++ b/src/api/client_server/backup.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use ruma::api::client::{ backup::{ add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, @@ -14,15 +14,12 @@ use ruma::api::client::{ /// /// Creates a new backup. pub async fn create_backup_version_route( - db: DatabaseGuard, body: Ruma<create_backup_version::v3::Request>, ) -> Result<create_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let version = db + let version = services() .key_backups - .create_backup(sender_user, &body.algorithm, &db.globals)?; - - db.flush()?; + .create_backup(sender_user, &body.algorithm)?; Ok(create_backup_version::v3::Response { version }) } @@ -31,14 +28,12 @@ pub async fn create_backup_version_route( /// /// Update information about an existing backup. Only `auth_data` can be modified. pub async fn update_backup_version_route( - db: DatabaseGuard, - body: Ruma<update_backup_version::v3::IncomingRequest>, + body: Ruma<update_backup_version::v3::Request>, ) -> Result<update_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups - .update_backup(sender_user, &body.version, &body.algorithm, &db.globals)?; - - db.flush()?; + services() + .key_backups + .update_backup(sender_user, &body.version, &body.algorithm)?; Ok(update_backup_version::v3::Response {}) } @@ -47,23 +42,22 @@ pub async fn update_backup_version_route( /// /// Get information about the latest backup version. pub async fn get_latest_backup_info_route( - db: DatabaseGuard, body: Ruma<get_latest_backup_info::v3::Request>, ) -> Result<get_latest_backup_info::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let (version, algorithm) = - db.key_backups - .get_latest_backup(sender_user)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Key backup does not exist.", - ))?; + let (version, algorithm) = services() + .key_backups + .get_latest_backup(sender_user)? + .ok_or(Error::BadRequest( + ErrorKind::NotFound, + "Key backup does not exist.", + ))?; Ok(get_latest_backup_info::v3::Response { algorithm, - count: (db.key_backups.count_keys(sender_user, &version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &version)?, + count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &version)?, version, }) } @@ -72,11 +66,10 @@ pub async fn get_latest_backup_info_route( /// /// Get information about an existing backup. pub async fn get_backup_info_route( - db: DatabaseGuard, - body: Ruma<get_backup_info::v3::IncomingRequest>, + body: Ruma<get_backup_info::v3::Request>, ) -> Result<get_backup_info::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let algorithm = db + let algorithm = services() .key_backups .get_backup(sender_user, &body.version)? .ok_or(Error::BadRequest( @@ -86,8 +79,13 @@ pub async fn get_backup_info_route( Ok(get_backup_info::v3::Response { algorithm, - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services() + .key_backups + .count_keys(sender_user, &body.version)? as u32) + .into(), + etag: services() + .key_backups + .get_etag(sender_user, &body.version)?, version: body.version.to_owned(), }) } @@ -98,14 +96,13 @@ pub async fn get_backup_info_route( /// /// - Deletes both information about the backup, as well as all key data related to the backup pub async fn delete_backup_version_route( - db: DatabaseGuard, - body: Ruma<delete_backup_version::v3::IncomingRequest>, + body: Ruma<delete_backup_version::v3::Request>, ) -> Result<delete_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups.delete_backup(sender_user, &body.version)?; - - db.flush()?; + services() + .key_backups + .delete_backup(sender_user, &body.version)?; Ok(delete_backup_version::v3::Response {}) } @@ -118,13 +115,12 @@ pub async fn delete_backup_version_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_route( - db: DatabaseGuard, - body: Ruma<add_backup_keys::v3::IncomingRequest>, + body: Ruma<add_backup_keys::v3::Request>, ) -> Result<add_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -137,22 +133,24 @@ pub async fn add_backup_keys_route( for (room_id, room) in &body.rooms { for (session_id, key_data) in &room.sessions { - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, room_id, session_id, key_data, - &db.globals, )? } } - db.flush()?; - Ok(add_backup_keys::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services() + .key_backups + .count_keys(sender_user, &body.version)? as u32) + .into(), + etag: services() + .key_backups + .get_etag(sender_user, &body.version)?, }) } @@ -164,13 +162,12 @@ pub async fn add_backup_keys_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_for_room_route( - db: DatabaseGuard, - body: Ruma<add_backup_keys_for_room::v3::IncomingRequest>, + body: Ruma<add_backup_keys_for_room::v3::Request>, ) -> Result<add_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -182,21 +179,23 @@ pub async fn add_backup_keys_for_room_route( } for (session_id, key_data) in &body.sessions { - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, &body.room_id, session_id, key_data, - &db.globals, )? } - db.flush()?; - Ok(add_backup_keys_for_room::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services() + .key_backups + .count_keys(sender_user, &body.version)? as u32) + .into(), + etag: services() + .key_backups + .get_etag(sender_user, &body.version)?, }) } @@ -208,13 +207,12 @@ pub async fn add_backup_keys_for_room_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_for_session_route( - db: DatabaseGuard, - body: Ruma<add_backup_keys_for_session::v3::IncomingRequest>, + body: Ruma<add_backup_keys_for_session::v3::Request>, ) -> Result<add_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -225,20 +223,22 @@ pub async fn add_backup_keys_for_session_route( )); } - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data, - &db.globals, )?; - db.flush()?; - Ok(add_backup_keys_for_session::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services() + .key_backups + .count_keys(sender_user, &body.version)? as u32) + .into(), + etag: services() + .key_backups + .get_etag(sender_user, &body.version)?, }) } @@ -246,12 +246,11 @@ pub async fn add_backup_keys_for_session_route( /// /// Retrieves all keys from the backup. pub async fn get_backup_keys_route( - db: DatabaseGuard, - body: Ruma<get_backup_keys::v3::IncomingRequest>, + body: Ruma<get_backup_keys::v3::Request>, ) -> Result<get_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = db.key_backups.get_all(sender_user, &body.version)?; + let rooms = services().key_backups.get_all(sender_user, &body.version)?; Ok(get_backup_keys::v3::Response { rooms }) } @@ -260,12 +259,11 @@ pub async fn get_backup_keys_route( /// /// Retrieves all keys from the backup for a given room. pub async fn get_backup_keys_for_room_route( - db: DatabaseGuard, - body: Ruma<get_backup_keys_for_room::v3::IncomingRequest>, + body: Ruma<get_backup_keys_for_room::v3::Request>, ) -> Result<get_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sessions = db + let sessions = services() .key_backups .get_room(sender_user, &body.version, &body.room_id)?; @@ -276,12 +274,11 @@ pub async fn get_backup_keys_for_room_route( /// /// Retrieves a key from the backup. pub async fn get_backup_keys_for_session_route( - db: DatabaseGuard, - body: Ruma<get_backup_keys_for_session::v3::IncomingRequest>, + body: Ruma<get_backup_keys_for_session::v3::Request>, ) -> Result<get_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let key_data = db + let key_data = services() .key_backups .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? .ok_or(Error::BadRequest( @@ -296,18 +293,22 @@ pub async fn get_backup_keys_for_session_route( /// /// Delete the keys from the backup. pub async fn delete_backup_keys_route( - db: DatabaseGuard, - body: Ruma<delete_backup_keys::v3::IncomingRequest>, + body: Ruma<delete_backup_keys::v3::Request>, ) -> Result<delete_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups.delete_all_keys(sender_user, &body.version)?; - - db.flush()?; + services() + .key_backups + .delete_all_keys(sender_user, &body.version)?; Ok(delete_backup_keys::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services() + .key_backups + .count_keys(sender_user, &body.version)? as u32) + .into(), + etag: services() + .key_backups + .get_etag(sender_user, &body.version)?, }) } @@ -315,19 +316,22 @@ pub async fn delete_backup_keys_route( /// /// Delete the keys from the backup for a given room. pub async fn delete_backup_keys_for_room_route( - db: DatabaseGuard, - body: Ruma<delete_backup_keys_for_room::v3::IncomingRequest>, + body: Ruma<delete_backup_keys_for_room::v3::Request>, ) -> Result<delete_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups + services() + .key_backups .delete_room_keys(sender_user, &body.version, &body.room_id)?; - db.flush()?; - Ok(delete_backup_keys_for_room::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services() + .key_backups + .count_keys(sender_user, &body.version)? as u32) + .into(), + etag: services() + .key_backups + .get_etag(sender_user, &body.version)?, }) } @@ -335,18 +339,24 @@ pub async fn delete_backup_keys_for_room_route( /// /// Delete a key from the backup. pub async fn delete_backup_keys_for_session_route( - db: DatabaseGuard, - body: Ruma<delete_backup_keys_for_session::v3::IncomingRequest>, + body: Ruma<delete_backup_keys_for_session::v3::Request>, ) -> Result<delete_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups - .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; - - db.flush()?; + services().key_backups.delete_room_key( + sender_user, + &body.version, + &body.room_id, + &body.session_id, + )?; Ok(delete_backup_keys_for_session::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services() + .key_backups + .count_keys(sender_user, &body.version)? as u32) + .into(), + etag: services() + .key_backups + .get_etag(sender_user, &body.version)?, }) } diff --git a/src/client_server/capabilities.rs b/src/api/client_server/capabilities.rs index 417ad29..233e3c9 100644 --- a/src/client_server/capabilities.rs +++ b/src/api/client_server/capabilities.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{services, Result, Ruma}; use ruma::api::client::discovery::get_capabilities::{ self, Capabilities, RoomVersionStability, RoomVersionsCapability, }; @@ -8,26 +8,19 @@ use std::collections::BTreeMap; /// /// Get information on the supported feature set and other relevent capabilities of this server. pub async fn get_capabilities_route( - db: DatabaseGuard, - _body: Ruma<get_capabilities::v3::IncomingRequest>, + _body: Ruma<get_capabilities::v3::Request>, ) -> Result<get_capabilities::v3::Response> { let mut available = BTreeMap::new(); - if db.globals.allow_unstable_room_versions() { - for room_version in &db.globals.unstable_room_versions { - available.insert(room_version.clone(), RoomVersionStability::Stable); - } - } else { - for room_version in &db.globals.unstable_room_versions { - available.insert(room_version.clone(), RoomVersionStability::Unstable); - } + for room_version in &services().globals.unstable_room_versions { + available.insert(room_version.clone(), RoomVersionStability::Unstable); } - for room_version in &db.globals.stable_room_versions { + for room_version in &services().globals.stable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Stable); } let mut capabilities = Capabilities::new(); capabilities.room_versions = RoomVersionsCapability { - default: db.globals.default_room_version(), + default: services().globals.default_room_version(), available, }; diff --git a/src/client_server/config.rs b/src/api/client_server/config.rs index 6184e0b..12f9aea 100644 --- a/src/client_server/config.rs +++ b/src/api/client_server/config.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use ruma::{ api::client::{ config::{ @@ -17,8 +17,7 @@ use serde_json::{json, value::RawValue as RawJsonValue}; /// /// Sets some account data for the sender user. pub async fn set_global_account_data_route( - db: DatabaseGuard, - body: Ruma<set_global_account_data::v3::IncomingRequest>, + body: Ruma<set_global_account_data::v3::Request>, ) -> Result<set_global_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -27,7 +26,7 @@ pub async fn set_global_account_data_route( let event_type = body.event_type.to_string(); - db.account_data.update( + services().account_data.update( None, sender_user, event_type.clone().into(), @@ -35,11 +34,8 @@ pub async fn set_global_account_data_route( "type": event_type, "content": data, }), - &db.globals, )?; - db.flush()?; - Ok(set_global_account_data::v3::Response {}) } @@ -47,8 +43,7 @@ pub async fn set_global_account_data_route( /// /// Sets some room account data for the sender user. pub async fn set_room_account_data_route( - db: DatabaseGuard, - body: Ruma<set_room_account_data::v3::IncomingRequest>, + body: Ruma<set_room_account_data::v3::Request>, ) -> Result<set_room_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -57,7 +52,7 @@ pub async fn set_room_account_data_route( let event_type = body.event_type.to_string(); - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, event_type.clone().into(), @@ -65,11 +60,8 @@ pub async fn set_room_account_data_route( "type": event_type, "content": data, }), - &db.globals, )?; - db.flush()?; - Ok(set_room_account_data::v3::Response {}) } @@ -77,12 +69,11 @@ pub async fn set_room_account_data_route( /// /// Gets some account data for the sender user. pub async fn get_global_account_data_route( - db: DatabaseGuard, - body: Ruma<get_global_account_data::v3::IncomingRequest>, + body: Ruma<get_global_account_data::v3::Request>, ) -> Result<get_global_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box<RawJsonValue> = db + let event: Box<RawJsonValue> = services() .account_data .get(None, sender_user, body.event_type.clone().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; @@ -98,12 +89,11 @@ pub async fn get_global_account_data_route( /// /// Gets some room account data for the sender user. pub async fn get_room_account_data_route( - db: DatabaseGuard, - body: Ruma<get_room_account_data::v3::IncomingRequest>, + body: Ruma<get_room_account_data::v3::Request>, ) -> Result<get_room_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box<RawJsonValue> = db + let event: Box<RawJsonValue> = services() .account_data .get( Some(&body.room_id), diff --git a/src/client_server/context.rs b/src/api/client_server/context.rs index e93f5a5..1e62f91 100644 --- a/src/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use ruma::{ api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, events::StateEventType, @@ -13,8 +13,7 @@ use tracing::error; /// - Only works if the user is joined (TODO: always allow, but only show events if the user was /// joined, depending on history_visibility) pub async fn get_context_route( - db: DatabaseGuard, - body: Ruma<get_context::v3::IncomingRequest>, + body: Ruma<get_context::v3::Request>, ) -> Result<get_context::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -28,18 +27,20 @@ pub async fn get_context_route( let mut lazy_loaded = HashSet::new(); - let base_pdu_id = db + let base_pdu_id = services() .rooms + .timeline .get_pdu_id(&body.event_id)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "Base event id not found.", ))?; - let base_token = db.rooms.pdu_count(&base_pdu_id)?; + let base_token = services().rooms.timeline.pdu_count(&base_pdu_id)?; - let base_event = db + let base_event = services() .rooms + .timeline .get_pdu_from_id(&base_pdu_id)? .ok_or(Error::BadRequest( ErrorKind::NotFound, @@ -48,14 +49,18 @@ pub async fn get_context_route( let room_id = base_event.room_id.clone(); - if !db.rooms.is_joined(sender_user, &room_id)? { + if !services() + .rooms + .state_cache + .is_joined(sender_user, &room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", )); } - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -67,8 +72,9 @@ pub async fn get_context_route( let base_event = base_event.to_room_event(); - let events_before: Vec<_> = db + let events_before: Vec<_> = services() .rooms + .timeline .pdus_until(sender_user, &room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { @@ -80,7 +86,7 @@ pub async fn get_context_route( .collect(); for (_, event) in &events_before { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -93,7 +99,7 @@ pub async fn get_context_route( let start_token = events_before .last() - .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) + .and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok()) .map(|count| count.to_string()); let events_before: Vec<_> = events_before @@ -101,8 +107,9 @@ pub async fn get_context_route( .map(|(_, pdu)| pdu.to_room_event()) .collect(); - let events_after: Vec<_> = db + let events_after: Vec<_> = services() .rooms + .timeline .pdus_after(sender_user, &room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { @@ -114,7 +121,7 @@ pub async fn get_context_route( .collect(); for (_, event) in &events_after { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -125,23 +132,28 @@ pub async fn get_context_route( } } - let shortstatehash = match db.rooms.pdu_shortstatehash( + let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash( events_after .last() .map_or(&*body.event_id, |(_, e)| &*e.event_id), )? { Some(s) => s, - None => db + None => services() .rooms - .current_shortstatehash(&room_id)? + .state + .get_room_shortstatehash(&room_id)? .expect("All rooms have state"), }; - let state_ids = db.rooms.state_full_ids(shortstatehash).await?; + let state_ids = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await?; let end_token = events_after .last() - .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) + .and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok()) .map(|count| count.to_string()); let events_after: Vec<_> = events_after @@ -152,10 +164,13 @@ pub async fn get_context_route( let mut state = Vec::new(); for (shortstatekey, id) in state_ids { - let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; + let (event_type, state_key) = services() + .rooms + .short + .get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.timeline.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -164,7 +179,7 @@ pub async fn get_context_route( }; state.push(pdu.to_state_event()); } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.timeline.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); diff --git a/src/client_server/device.rs b/src/api/client_server/device.rs index b100bf2..aba061b 100644 --- a/src/client_server/device.rs +++ b/src/api/client_server/device.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{services, utils, Error, Result, Ruma}; use ruma::api::client::{ device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, error::ErrorKind, @@ -11,12 +11,11 @@ use super::SESSION_ID_LENGTH; /// /// Get metadata on all devices of the sender user. pub async fn get_devices_route( - db: DatabaseGuard, body: Ruma<get_devices::v3::Request>, ) -> Result<get_devices::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let devices: Vec<device::Device> = db + let devices: Vec<device::Device> = services() .users .all_devices_metadata(sender_user) .filter_map(|r| r.ok()) // Filter out buggy devices @@ -29,12 +28,11 @@ pub async fn get_devices_route( /// /// Get metadata on a single device of the sender user. pub async fn get_device_route( - db: DatabaseGuard, - body: Ruma<get_device::v3::IncomingRequest>, + body: Ruma<get_device::v3::Request>, ) -> Result<get_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let device = db + let device = services() .users .get_device_metadata(sender_user, &body.body.device_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; @@ -46,23 +44,21 @@ pub async fn get_device_route( /// /// Updates the metadata on a given device of the sender user. pub async fn update_device_route( - db: DatabaseGuard, - body: Ruma<update_device::v3::IncomingRequest>, + body: Ruma<update_device::v3::Request>, ) -> Result<update_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut device = db + let mut device = services() .users .get_device_metadata(sender_user, &body.device_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; device.display_name = body.display_name.clone(); - db.users + services() + .users .update_device_metadata(sender_user, &body.device_id, &device)?; - db.flush()?; - Ok(update_device::v3::Response {}) } @@ -76,8 +72,7 @@ pub async fn update_device_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn delete_device_route( - db: DatabaseGuard, - body: Ruma<delete_device::v3::IncomingRequest>, + body: Ruma<delete_device::v3::Request>, ) -> Result<delete_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -94,30 +89,27 @@ pub async fn delete_device_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( - sender_user, - sender_device, - auth, - &uiaainfo, - &db.users, - &db.globals, - )?; + let (worked, uiaainfo) = + services() + .uiaa + .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services() + .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - db.users.remove_device(sender_user, &body.device_id)?; - - db.flush()?; + services() + .users + .remove_device(sender_user, &body.device_id)?; Ok(delete_device::v3::Response {}) } @@ -134,8 +126,7 @@ pub async fn delete_device_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn delete_devices_route( - db: DatabaseGuard, - body: Ruma<delete_devices::v3::IncomingRequest>, + body: Ruma<delete_devices::v3::Request>, ) -> Result<delete_devices::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -152,21 +143,18 @@ pub async fn delete_devices_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( - sender_user, - sender_device, - auth, - &uiaainfo, - &db.users, - &db.globals, - )?; + let (worked, uiaainfo) = + services() + .uiaa + .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services() + .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -174,10 +162,8 @@ pub async fn delete_devices_route( } for device_id in &body.devices { - db.users.remove_device(sender_user, device_id)? + services().users.remove_device(sender_user, device_id)? } - db.flush()?; - Ok(delete_devices::v3::Response {}) } diff --git a/src/client_server/directory.rs b/src/api/client_server/directory.rs index 4e4a322..e132210 100644 --- a/src/client_server/directory.rs +++ b/src/api/client_server/directory.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use ruma::{ api::{ client::{ @@ -11,14 +11,12 @@ use ruma::{ }, federation, }, - directory::{ - Filter, IncomingFilter, IncomingRoomNetwork, PublicRoomJoinRule, PublicRoomsChunk, - RoomNetwork, - }, + directory::{Filter, PublicRoomJoinRule, PublicRoomsChunk, RoomNetwork}, events::{ room::{ avatar::RoomAvatarEventContent, canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent}, history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, join_rules::{JoinRule, RoomJoinRulesEventContent}, @@ -29,7 +27,7 @@ use ruma::{ }, ServerName, UInt, }; -use tracing::{info, warn}; +use tracing::{error, info, warn}; /// # `POST /_matrix/client/r0/publicRooms` /// @@ -37,11 +35,9 @@ use tracing::{info, warn}; /// /// - Rooms are ordered by the number of joined members pub async fn get_public_rooms_filtered_route( - db: DatabaseGuard, - body: Ruma<get_public_rooms_filtered::v3::IncomingRequest>, + body: Ruma<get_public_rooms_filtered::v3::Request>, ) -> Result<get_public_rooms_filtered::v3::Response> { get_public_rooms_filtered_helper( - &db, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -57,16 +53,14 @@ pub async fn get_public_rooms_filtered_route( /// /// - Rooms are ordered by the number of joined members pub async fn get_public_rooms_route( - db: DatabaseGuard, - body: Ruma<get_public_rooms::v3::IncomingRequest>, + body: Ruma<get_public_rooms::v3::Request>, ) -> Result<get_public_rooms::v3::Response> { let response = get_public_rooms_filtered_helper( - &db, body.server.as_deref(), body.limit, body.since.as_deref(), - &IncomingFilter::default(), - &IncomingRoomNetwork::Matrix, + &Filter::default(), + &RoomNetwork::Matrix, ) .await?; @@ -84,17 +78,21 @@ pub async fn get_public_rooms_route( /// /// - TODO: Access control checks pub async fn set_room_visibility_route( - db: DatabaseGuard, - body: Ruma<set_room_visibility::v3::IncomingRequest>, + body: Ruma<set_room_visibility::v3::Request>, ) -> Result<set_room_visibility::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + if !services().rooms.metadata.exists(&body.room_id)? { + // Return 404 if the room doesn't exist + return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); + } + match &body.visibility { room::Visibility::Public => { - db.rooms.set_public(&body.room_id, true)?; + services().rooms.directory.set_public(&body.room_id)?; info!("{} made {} public", sender_user, body.room_id); } - room::Visibility::Private => db.rooms.set_public(&body.room_id, false)?, + room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -103,8 +101,6 @@ pub async fn set_room_visibility_route( } } - db.flush()?; - Ok(set_room_visibility::v3::Response {}) } @@ -112,11 +108,15 @@ pub async fn set_room_visibility_route( /// /// Gets the visibility of a given room in the room directory. pub async fn get_room_visibility_route( - db: DatabaseGuard, - body: Ruma<get_room_visibility::v3::IncomingRequest>, + body: Ruma<get_room_visibility::v3::Request>, ) -> Result<get_room_visibility::v3::Response> { + if !services().rooms.metadata.exists(&body.room_id)? { + // Return 404 if the room doesn't exist + return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); + } + Ok(get_room_visibility::v3::Response { - visibility: if db.rooms.is_public_room(&body.room_id)? { + visibility: if services().rooms.directory.is_public_room(&body.room_id)? { room::Visibility::Public } else { room::Visibility::Private @@ -125,25 +125,25 @@ pub async fn get_room_visibility_route( } pub(crate) async fn get_public_rooms_filtered_helper( - db: &Database, server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, - filter: &IncomingFilter, - _network: &IncomingRoomNetwork, + filter: &Filter, + _network: &RoomNetwork, ) -> Result<get_public_rooms_filtered::v3::Response> { - if let Some(other_server) = server.filter(|server| *server != db.globals.server_name().as_str()) + if let Some(other_server) = + server.filter(|server| *server != services().globals.server_name().as_str()) { - let response = db + let response = services() .sending .send_federation_request( - &db.globals, other_server, federation::directory::get_public_rooms_filtered::v1::Request { limit, - since, + since: since.map(ToOwned::to_owned), filter: Filter { - generic_search_term: filter.generic_search_term.as_deref(), + generic_search_term: filter.generic_search_term.clone(), + room_types: filter.room_types.clone(), }, room_network: RoomNetwork::Matrix, }, @@ -184,15 +184,17 @@ pub(crate) async fn get_public_rooms_filtered_helper( } } - let mut all_rooms: Vec<_> = db + let mut all_rooms: Vec<_> = services() .rooms + .directory .public_rooms() .map(|room_id| { let room_id = room_id?; let chunk = PublicRoomsChunk { - canonical_alias: db + canonical_alias: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) @@ -201,8 +203,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid canonical alias event in database.") }) })?, - name: db + name: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomName, "")? .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) @@ -211,8 +214,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room name event in database.") }) })?, - num_joined_members: db + num_joined_members: services() .rooms + .state_cache .room_joined_count(&room_id)? .unwrap_or_else(|| { warn!("Room {} has no member count", room_id); @@ -220,8 +224,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( }) .try_into() .expect("user count should not be that big"), - topic: db + topic: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomTopic, "")? .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) @@ -230,8 +235,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room topic event in database.") }) })?, - world_readable: db + world_readable: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) @@ -244,8 +250,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( ) }) })?, - guest_can_join: db + guest_can_join: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) @@ -256,8 +263,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room guest access event in database.") }) })?, - avatar_url: db + avatar_url: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomAvatar, "")? .map(|s| { serde_json::from_str(s.content.get()) @@ -269,8 +277,9 @@ pub(crate) async fn get_public_rooms_filtered_helper( .transpose()? // url is now an Option<String> so we must flatten .flatten(), - join_rule: db + join_rule: services() .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? .map(|s| { serde_json::from_str(s.content.get()) @@ -279,15 +288,28 @@ pub(crate) async fn get_public_rooms_filtered_helper( JoinRule::Knock => Some(PublicRoomJoinRule::Knock), _ => None, }) - .map_err(|_| { - Error::bad_database("Invalid room join rule event in database.") + .map_err(|e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") }) }) .transpose()? .flatten() - .ok_or(Error::bad_database( - "Invalid room join rule event in database.", - ))?, + .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, + room_type: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCreate, "")? + .map(|s| { + serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err( + |e| { + error!("Invalid room create event in database: {}", e); + Error::BadDatabase("Invalid room create event in database.") + }, + ) + }) + .transpose()? + .and_then(|e| e.room_type), room_id, }; Ok(chunk) @@ -339,7 +361,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( let prev_batch = if num_since == 0 { None } else { - Some(format!("p{}", num_since)) + Some(format!("p{num_since}")) }; let next_batch = if chunk.len() < limit as usize { diff --git a/src/client_server/filter.rs b/src/api/client_server/filter.rs index 6522c90..e9a359d 100644 --- a/src/client_server/filter.rs +++ b/src/api/client_server/filter.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use ruma::api::client::{ error::ErrorKind, filter::{create_filter, get_filter}, @@ -10,11 +10,10 @@ use ruma::api::client::{ /// /// - A user can only access their own filters pub async fn get_filter_route( - db: DatabaseGuard, - body: Ruma<get_filter::v3::IncomingRequest>, + body: Ruma<get_filter::v3::Request>, ) -> Result<get_filter::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let filter = match db.users.get_filter(sender_user, &body.filter_id)? { + let filter = match services().users.get_filter(sender_user, &body.filter_id)? { Some(filter) => filter, None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")), }; @@ -26,11 +25,10 @@ pub async fn get_filter_route( /// /// Creates a new filter to be used by other endpoints. pub async fn create_filter_route( - db: DatabaseGuard, - body: Ruma<create_filter::v3::IncomingRequest>, + body: Ruma<create_filter::v3::Request>, ) -> Result<create_filter::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(create_filter::v3::Response::new( - db.users.create_filter(sender_user, &body.filter)?, + services().users.create_filter(sender_user, &body.filter)?, )) } diff --git a/src/client_server/keys.rs b/src/api/client_server/keys.rs index c4f91cb..ba89ece 100644 --- a/src/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -1,5 +1,5 @@ use super::SESSION_ID_LENGTH; -use crate::{database::DatabaseGuard, utils, Database, Error, Result, Ruma}; +use crate::{services, utils, Error, Result, Ruma}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ @@ -14,7 +14,7 @@ use ruma::{ federation, }, serde::Raw, - DeviceId, DeviceKeyAlgorithm, UserId, + DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, }; use serde_json::json; use std::collections::{BTreeMap, HashMap, HashSet}; @@ -26,39 +26,35 @@ use std::collections::{BTreeMap, HashMap, HashSet}; /// - Adds one time keys /// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?) pub async fn upload_keys_route( - db: DatabaseGuard, body: Ruma<upload_keys::v3::Request>, ) -> Result<upload_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); for (key_key, key_value) in &body.one_time_keys { - db.users - .add_one_time_key(sender_user, sender_device, key_key, key_value, &db.globals)?; + services() + .users + .add_one_time_key(sender_user, sender_device, key_key, key_value)?; } if let Some(device_keys) = &body.device_keys { // TODO: merge this and the existing event? // This check is needed to assure that signatures are kept - if db + if services() .users .get_device_keys(sender_user, sender_device)? .is_none() { - db.users.add_device_keys( - sender_user, - sender_device, - device_keys, - &db.rooms, - &db.globals, - )?; + services() + .users + .add_device_keys(sender_user, sender_device, device_keys)?; } } - db.flush()?; - Ok(upload_keys::v3::Response { - one_time_key_counts: db.users.count_one_time_keys(sender_user, sender_device)?, + one_time_key_counts: services() + .users + .count_one_time_keys(sender_user, sender_device)?, }) } @@ -69,19 +65,11 @@ pub async fn upload_keys_route( /// - Always fetches users from other servers over federation /// - Gets master keys, self-signing keys, user signing keys and device keys. /// - The master and self-signing keys contain signatures that the user is allowed to see -pub async fn get_keys_route( - db: DatabaseGuard, - body: Ruma<get_keys::v3::IncomingRequest>, -) -> Result<get_keys::v3::Response> { +pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let response = get_keys_helper( - Some(sender_user), - &body.device_keys, - |u| u == sender_user, - &db, - ) - .await?; + let response = + get_keys_helper(Some(sender_user), &body.device_keys, |u| u == sender_user).await?; Ok(response) } @@ -90,12 +78,9 @@ pub async fn get_keys_route( /// /// Claims one-time keys pub async fn claim_keys_route( - db: DatabaseGuard, body: Ruma<claim_keys::v3::Request>, ) -> Result<claim_keys::v3::Response> { - let response = claim_keys_helper(&body.one_time_keys, &db).await?; - - db.flush()?; + let response = claim_keys_helper(&body.one_time_keys).await?; Ok(response) } @@ -106,8 +91,7 @@ pub async fn claim_keys_route( /// /// - Requires UIAA to verify password pub async fn upload_signing_keys_route( - db: DatabaseGuard, - body: Ruma<upload_signing_keys::v3::IncomingRequest>, + body: Ruma<upload_signing_keys::v3::Request>, ) -> Result<upload_signing_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -124,21 +108,18 @@ pub async fn upload_signing_keys_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( - sender_user, - sender_device, - auth, - &uiaainfo, - &db.users, - &db.globals, - )?; + let (worked, uiaainfo) = + services() + .uiaa + .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { return Err(Error::Uiaa(uiaainfo)); } // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services() + .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -146,18 +127,14 @@ pub async fn upload_signing_keys_route( } if let Some(master_key) = &body.master_key { - db.users.add_cross_signing_keys( + services().users.add_cross_signing_keys( sender_user, master_key, &body.self_signing_key, &body.user_signing_key, - &db.rooms, - &db.globals, )?; } - db.flush()?; - Ok(upload_signing_keys::v3::Response {}) } @@ -165,16 +142,28 @@ pub async fn upload_signing_keys_route( /// /// Uploads end-to-end key signatures from the sender user. pub async fn upload_signatures_route( - db: DatabaseGuard, body: Ruma<upload_signatures::v3::Request>, ) -> Result<upload_signatures::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for (user_id, signed_keys) in &body.signed_keys { - for (key_id, signed_key) in signed_keys { - let signed_key = serde_json::to_value(signed_key).unwrap(); + for (user_id, keys) in &body.signed_keys { + for (key_id, key) in keys { + let key = serde_json::to_value(key) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?; + + let is_signed_key = match key.get("usage") { + Some(usage) => usage + .as_array() + .map(|usage| !usage.contains(&json!("master"))) + .unwrap_or(false), + None => true, + }; + + if !is_signed_key { + continue; + } - for signature in signed_key + for signature in key .get("signatures") .ok_or(Error::BadRequest( ErrorKind::InvalidParam, @@ -205,20 +194,13 @@ pub async fn upload_signatures_route( ))? .to_owned(), ); - db.users.sign_key( - user_id, - key_id, - signature, - sender_user, - &db.rooms, - &db.globals, - )?; + services() + .users + .sign_key(user_id, key_id, signature, sender_user)?; } } } - db.flush()?; - Ok(upload_signatures::v3::Response { failures: BTreeMap::new(), // TODO: integrate }) @@ -230,15 +212,15 @@ pub async fn upload_signatures_route( /// /// - TODO: left users pub async fn get_key_changes_route( - db: DatabaseGuard, - body: Ruma<get_key_changes::v3::IncomingRequest>, + body: Ruma<get_key_changes::v3::Request>, ) -> Result<get_key_changes::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut device_list_updates = HashSet::new(); device_list_updates.extend( - db.users + services() + .users .keys_changed( sender_user.as_str(), body.from @@ -253,11 +235,17 @@ pub async fn get_key_changes_route( .filter_map(|r| r.ok()), ); - for room_id in db.rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) { + for room_id in services() + .rooms + .state_cache + .rooms_joined(sender_user) + .filter_map(|r| r.ok()) + { device_list_updates.extend( - db.users + services() + .users .keys_changed( - &room_id.to_string(), + room_id.as_ref(), body.from.parse().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.") })?, @@ -276,9 +264,8 @@ pub async fn get_key_changes_route( pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( sender_user: Option<&UserId>, - device_keys_input: &BTreeMap<Box<UserId>, Vec<Box<DeviceId>>>, + device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F, - db: &Database, ) -> Result<get_keys::v3::Response> { let mut master_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new(); @@ -288,9 +275,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( let mut get_over_federation = HashMap::new(); for (user_id, device_ids) in device_keys_input { - let user_id: &UserId = &**user_id; + let user_id: &UserId = user_id; - if user_id.server_name() != db.globals.server_name() { + if user_id.server_name() != services().globals.server_name() { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) @@ -300,10 +287,10 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( if device_ids.is_empty() { let mut container = BTreeMap::new(); - for device_id in db.users.all_device_ids(user_id) { + for device_id in services().users.all_device_ids(user_id) { let device_id = device_id?; - if let Some(mut keys) = db.users.get_device_keys(user_id, &device_id)? { - let metadata = db + if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { + let metadata = services() .users .get_device_metadata(user_id, &device_id)? .ok_or_else(|| { @@ -319,13 +306,14 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = db.users.get_device_keys(user_id, device_id)? { - let metadata = db.users.get_device_metadata(user_id, device_id)?.ok_or( - Error::BadRequest( + if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { + let metadata = services() + .users + .get_device_metadata(user_id, device_id)? + .ok_or(Error::BadRequest( ErrorKind::InvalidParam, "Tried to get keys for nonexistent device.", - ), - )?; + ))?; add_unsigned_device_display_name(&mut keys, metadata) .map_err(|_| Error::bad_database("invalid device keys in database"))?; @@ -335,17 +323,20 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( } } - if let Some(master_key) = db.users.get_master_key(user_id, &allowed_signatures)? { + if let Some(master_key) = services() + .users + .get_master_key(user_id, &allowed_signatures)? + { master_keys.insert(user_id.to_owned(), master_key); } - if let Some(self_signing_key) = db + if let Some(self_signing_key) = services() .users .get_self_signing_key(user_id, &allowed_signatures)? { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = db.users.get_user_signing_key(user_id)? { + if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? { user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -362,9 +353,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( } ( server, - db.sending + services() + .sending .send_federation_request( - &db.globals, server, federation::keys::get_keys::v1::Request { device_keys: device_keys_input_fed, @@ -416,15 +407,14 @@ fn add_unsigned_device_display_name( } pub(crate) async fn claim_keys_helper( - one_time_keys_input: &BTreeMap<Box<UserId>, BTreeMap<Box<DeviceId>, DeviceKeyAlgorithm>>, - db: &Database, + one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>, ) -> Result<claim_keys::v3::Response> { let mut one_time_keys = BTreeMap::new(); let mut get_over_federation = BTreeMap::new(); for (user_id, map) in one_time_keys_input { - if user_id.server_name() != db.globals.server_name() { + if user_id.server_name() != services().globals.server_name() { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) @@ -434,8 +424,9 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { if let Some(one_time_keys) = - db.users - .take_one_time_key(user_id, device_id, key_algorithm, &db.globals)? + services() + .users + .take_one_time_key(user_id, device_id, key_algorithm)? { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); @@ -447,26 +438,36 @@ pub(crate) async fn claim_keys_helper( let mut failures = BTreeMap::new(); - for (server, vec) in get_over_federation { - let mut one_time_keys_input_fed = BTreeMap::new(); - for (user_id, keys) in vec { - one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); - } - // Ignore failures - if let Ok(keys) = db - .sending - .send_federation_request( - &db.globals, + let mut futures: FuturesUnordered<_> = get_over_federation + .into_iter() + .map(|(server, vec)| async move { + let mut one_time_keys_input_fed = BTreeMap::new(); + for (user_id, keys) in vec { + one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); + } + ( server, - federation::keys::claim_keys::v1::Request { - one_time_keys: one_time_keys_input_fed, - }, + services() + .sending + .send_federation_request( + server, + federation::keys::claim_keys::v1::Request { + one_time_keys: one_time_keys_input_fed, + }, + ) + .await, ) - .await - { - one_time_keys.extend(keys.one_time_keys); - } else { - failures.insert(server.to_string(), json!({})); + }) + .collect(); + + while let Some((server, response)) = futures.next().await { + match response { + Ok(keys) => { + one_time_keys.extend(keys.one_time_keys); + } + Err(_e) => { + failures.insert(server.to_string(), json!({})); + } } } diff --git a/src/client_server/media.rs b/src/api/client_server/media.rs index a9a6d6c..3410cc0 100644 --- a/src/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -1,7 +1,4 @@ -use crate::{ - database::{media::FileMeta, DatabaseGuard}, - utils, Error, Result, Ruma, -}; +use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma}; use ruma::api::client::{ error::ErrorKind, media::{ @@ -16,11 +13,10 @@ const MXC_LENGTH: usize = 32; /// /// Returns max upload size. pub async fn get_media_config_route( - db: DatabaseGuard, _body: Ruma<get_media_config::v3::Request>, ) -> Result<get_media_config::v3::Response> { Ok(get_media_config::v3::Response { - upload_size: db.globals.max_request_size().into(), + upload_size: services().globals.max_request_size().into(), }) } @@ -31,31 +27,27 @@ pub async fn get_media_config_route( /// - Some metadata will be saved in the database /// - Media will be saved in the media/ directory pub async fn create_content_route( - db: DatabaseGuard, - body: Ruma<create_content::v3::IncomingRequest>, + body: Ruma<create_content::v3::Request>, ) -> Result<create_content::v3::Response> { let mxc = format!( "mxc://{}/{}", - db.globals.server_name(), + services().globals.server_name(), utils::random_string(MXC_LENGTH) ); - db.media + services() + .media .create( mxc.clone(), - &db.globals, - &body - .filename + body.filename .as_ref() .map(|filename| "inline; filename=".to_owned() + filename) .as_deref(), - &body.content_type.as_deref(), + body.content_type.as_deref(), &body.file, ) .await?; - db.flush()?; - Ok(create_content::v3::Response { content_uri: mxc.try_into().expect("Invalid mxc:// URI"), blurhash: None, @@ -63,30 +55,28 @@ pub async fn create_content_route( } pub async fn get_remote_content( - db: &DatabaseGuard, mxc: &str, server_name: &ruma::ServerName, - media_id: &str, + media_id: String, ) -> Result<get_content::v3::Response, Error> { - let content_response = db + let content_response = services() .sending .send_federation_request( - &db.globals, server_name, get_content::v3::Request { allow_remote: false, - server_name, + server_name: server_name.to_owned(), media_id, }, ) .await?; - db.media + services() + .media .create( - mxc.to_string(), - &db.globals, - &content_response.content_disposition.as_deref(), - &content_response.content_type.as_deref(), + mxc.to_owned(), + content_response.content_disposition.as_deref(), + content_response.content_type.as_deref(), &content_response.file, ) .await?; @@ -100,8 +90,7 @@ pub async fn get_remote_content( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_route( - db: DatabaseGuard, - body: Ruma<get_content::v3::IncomingRequest>, + body: Ruma<get_content::v3::Request>, ) -> Result<get_content::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -109,16 +98,17 @@ pub async fn get_content_route( content_disposition, content_type, file, - }) = db.media.get(&db.globals, &mxc).await? + }) = services().media.get(mxc.clone()).await? { Ok(get_content::v3::Response { file, content_type, content_disposition, + cross_origin_resource_policy: Some("cross-origin".to_owned()), }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { let remote_content_response = - get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?; + get_remote_content(&mxc, &body.server_name, body.media_id.clone()).await?; Ok(remote_content_response) } else { Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) @@ -131,8 +121,7 @@ pub async fn get_content_route( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_as_filename_route( - db: DatabaseGuard, - body: Ruma<get_content_as_filename::v3::IncomingRequest>, + body: Ruma<get_content_as_filename::v3::Request>, ) -> Result<get_content_as_filename::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -140,21 +129,23 @@ pub async fn get_content_as_filename_route( content_disposition: _, content_type, file, - }) = db.media.get(&db.globals, &mxc).await? + }) = services().media.get(mxc.clone()).await? { Ok(get_content_as_filename::v3::Response { file, content_type, content_disposition: Some(format!("inline; filename={}", body.filename)), + cross_origin_resource_policy: Some("cross-origin".to_owned()), }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { let remote_content_response = - get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?; + get_remote_content(&mxc, &body.server_name, body.media_id.clone()).await?; Ok(get_content_as_filename::v3::Response { content_disposition: Some(format!("inline: filename={}", body.filename)), content_type: remote_content_response.content_type, file: remote_content_response.file, + cross_origin_resource_policy: Some("cross-origin".to_owned()), }) } else { Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) @@ -167,18 +158,16 @@ pub async fn get_content_as_filename_route( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_thumbnail_route( - db: DatabaseGuard, - body: Ruma<get_content_thumbnail::v3::IncomingRequest>, + body: Ruma<get_content_thumbnail::v3::Request>, ) -> Result<get_content_thumbnail::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { content_type, file, .. - }) = db + }) = services() .media .get_thumbnail( - &mxc, - &db.globals, + mxc.clone(), body.width .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, @@ -188,30 +177,33 @@ pub async fn get_content_thumbnail_route( ) .await? { - Ok(get_content_thumbnail::v3::Response { file, content_type }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { - let get_thumbnail_response = db + Ok(get_content_thumbnail::v3::Response { + file, + content_type, + cross_origin_resource_policy: Some("cross-origin".to_owned()), + }) + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + let get_thumbnail_response = services() .sending .send_federation_request( - &db.globals, &body.server_name, get_content_thumbnail::v3::Request { allow_remote: false, height: body.height, width: body.width, method: body.method.clone(), - server_name: &body.server_name, - media_id: &body.media_id, + server_name: body.server_name.clone(), + media_id: body.media_id.clone(), }, ) .await?; - db.media + services() + .media .upload_thumbnail( mxc, - &db.globals, - &None, - &get_thumbnail_response.content_type, + None, + get_thumbnail_response.content_type.as_deref(), body.width.try_into().expect("all UInts are valid u32s"), body.height.try_into().expect("all UInts are valid u32s"), &get_thumbnail_response.file, diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs new file mode 100644 index 0000000..61c67cb --- /dev/null +++ b/src/api/client_server/membership.rs @@ -0,0 +1,1530 @@ +use ruma::{ + api::{ + client::{ + error::ErrorKind, + membership::{ + ban_user, forget_room, get_member_events, invite_user, join_room_by_id, + join_room_by_id_or_alias, joined_members, joined_rooms, kick_user, leave_room, + unban_user, ThirdPartySigned, + }, + }, + federation::{self, membership::create_invite}, + }, + canonical_json::to_canonical_value, + events::{ + room::{ + join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + power_levels::RoomPowerLevelsEventContent, + }, + RoomEventType, StateEventType, + }, + serde::Base64, + state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, + OwnedServerName, OwnedUserId, RoomId, RoomVersionId, UserId, +}; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use std::{ + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + sync::{Arc, RwLock}, + time::{Duration, Instant}, +}; +use tracing::{debug, error, warn}; + +use crate::{ + service::pdu::{gen_event_id_canonical_json, PduBuilder}, + services, utils, Error, PduEvent, Result, Ruma, +}; + +use super::get_alias_helper; + +/// # `POST /_matrix/client/r0/rooms/{roomId}/join` +/// +/// Tries to join the sender user into a room. +/// +/// - If the server knowns about this room: creates the join event and does auth rules locally +/// - If the server does not know about the room: asks other servers over federation +pub async fn join_room_by_id_route( + body: Ruma<join_room_by_id::v3::Request>, +) -> Result<join_room_by_id::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let mut servers = Vec::new(); // There is no body.server_name for /roomId/join + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); + + servers.push(body.room_id.server_name().to_owned()); + + join_room_by_id_helper( + body.sender_user.as_deref(), + &body.room_id, + body.reason.clone(), + &servers, + body.third_party_signed.as_ref(), + ) + .await +} + +/// # `POST /_matrix/client/r0/join/{roomIdOrAlias}` +/// +/// Tries to join the sender user into a room. +/// +/// - If the server knowns about this room: creates the join event and does auth rules locally +/// - If the server does not know about the room: asks other servers over federation +pub async fn join_room_by_id_or_alias_route( + body: Ruma<join_room_by_id_or_alias::v3::Request>, +) -> Result<join_room_by_id_or_alias::v3::Response> { + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); + let body = body.body; + + let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { + Ok(room_id) => { + let mut servers = body.server_name.clone(); + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); + + servers.push(room_id.server_name().to_owned()); + (servers, room_id) + } + Err(room_alias) => { + let response = get_alias_helper(room_alias).await?; + + (response.servers.into_iter().collect(), response.room_id) + } + }; + + let join_room_response = join_room_by_id_helper( + Some(sender_user), + &room_id, + body.reason.clone(), + &servers, + body.third_party_signed.as_ref(), + ) + .await?; + + Ok(join_room_by_id_or_alias::v3::Response { + room_id: join_room_response.room_id, + }) +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/leave` +/// +/// Tries to leave the sender user from a room. +/// +/// - This should always work if the user is currently joined. +pub async fn leave_room_route( + body: Ruma<leave_room::v3::Request>, +) -> Result<leave_room::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + leave_room(sender_user, &body.room_id, body.reason.clone()).await?; + + Ok(leave_room::v3::Response::new()) +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/invite` +/// +/// Tries to send an invite event into the room. +pub async fn invite_user_route( + body: Ruma<invite_user::v3::Request>, +) -> Result<invite_user::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + if let invite_user::v3::InvitationRecipient::UserId { user_id } = &body.recipient { + invite_helper( + sender_user, + user_id, + &body.room_id, + body.reason.clone(), + false, + ) + .await?; + Ok(invite_user::v3::Response {}) + } else { + Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) + } +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/kick` +/// +/// Tries to send a kick event into the room. +pub async fn kick_user_route( + body: Ruma<kick_user::v3::Request>, +) -> Result<kick_user::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let mut event: RoomMemberEventContent = serde_json::from_str( + services() + .rooms + .state_accessor + .room_state_get( + &body.room_id, + &StateEventType::RoomMember, + body.user_id.as_ref(), + )? + .ok_or(Error::BadRequest( + ErrorKind::BadState, + "Cannot kick member that's not in the room.", + ))? + .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; + + event.membership = MembershipState::Leave; + event.reason = body.reason.clone(); + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + )?; + + drop(state_lock); + + Ok(kick_user::v3::Response::new()) +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/ban` +/// +/// Tries to send a ban event into the room. +pub async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result<ban_user::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let event = services() + .rooms + .state_accessor + .room_state_get( + &body.room_id, + &StateEventType::RoomMember, + body.user_id.as_ref(), + )? + .map_or( + Ok(RoomMemberEventContent { + membership: MembershipState::Ban, + displayname: services().users.displayname(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(&body.user_id)?, + reason: body.reason.clone(), + join_authorized_via_users_server: None, + }), + |event| { + serde_json::from_str(event.content.get()) + .map(|event: RoomMemberEventContent| RoomMemberEventContent { + membership: MembershipState::Ban, + ..event + }) + .map_err(|_| Error::bad_database("Invalid member event in database.")) + }, + )?; + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + )?; + + drop(state_lock); + + Ok(ban_user::v3::Response::new()) +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/unban` +/// +/// Tries to send an unban event into the room. +pub async fn unban_user_route( + body: Ruma<unban_user::v3::Request>, +) -> Result<unban_user::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let mut event: RoomMemberEventContent = serde_json::from_str( + services() + .rooms + .state_accessor + .room_state_get( + &body.room_id, + &StateEventType::RoomMember, + body.user_id.as_ref(), + )? + .ok_or(Error::BadRequest( + ErrorKind::BadState, + "Cannot unban a user who is not banned.", + ))? + .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; + + event.membership = MembershipState::Leave; + event.reason = body.reason.clone(); + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + )?; + + drop(state_lock); + + Ok(unban_user::v3::Response::new()) +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/forget` +/// +/// Forgets about a room. +/// +/// - If the sender user currently left the room: Stops sender user from receiving information about the room +/// +/// Note: Other devices of the user have no way of knowing the room was forgotten, so this has to +/// be called from every device +pub async fn forget_room_route( + body: Ruma<forget_room::v3::Request>, +) -> Result<forget_room::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + services() + .rooms + .state_cache + .forget(&body.room_id, sender_user)?; + + Ok(forget_room::v3::Response::new()) +} + +/// # `POST /_matrix/client/r0/joined_rooms` +/// +/// Lists all rooms the user has joined. +pub async fn joined_rooms_route( + body: Ruma<joined_rooms::v3::Request>, +) -> Result<joined_rooms::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + Ok(joined_rooms::v3::Response { + joined_rooms: services() + .rooms + .state_cache + .rooms_joined(sender_user) + .filter_map(|r| r.ok()) + .collect(), + }) +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/members` +/// +/// Lists all joined users in a room (TODO: at a specific point in time, with a specific membership). +/// +/// - Only works if the user is currently joined +pub async fn get_member_events_route( + body: Ruma<get_member_events::v3::Request>, +) -> Result<get_member_events::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + // TODO: check history visibility? + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } + + Ok(get_member_events::v3::Response { + chunk: services() + .rooms + .state_accessor + .room_state_full(&body.room_id) + .await? + .iter() + .filter(|(key, _)| key.0 == StateEventType::RoomMember) + .map(|(_, pdu)| pdu.to_member_event()) + .collect(), + }) +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/joined_members` +/// +/// Lists all members of a room. +/// +/// - The sender user must be in the room +/// - TODO: An appservice just needs a puppet joined +pub async fn joined_members_route( + body: Ruma<joined_members::v3::Request>, +) -> Result<joined_members::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You aren't a member of the room.", + )); + } + + let mut joined = BTreeMap::new(); + for user_id in services() + .rooms + .state_cache + .room_members(&body.room_id) + .filter_map(|r| r.ok()) + { + let display_name = services().users.displayname(&user_id)?; + let avatar_url = services().users.avatar_url(&user_id)?; + + joined.insert( + user_id, + joined_members::v3::RoomMember { + display_name, + avatar_url, + }, + ); + } + + Ok(joined_members::v3::Response { joined }) +} + +async fn join_room_by_id_helper( + sender_user: Option<&UserId>, + room_id: &RoomId, + reason: Option<String>, + servers: &[OwnedServerName], + _third_party_signed: Option<&ThirdPartySigned>, +) -> Result<join_room_by_id::v3::Response> { + let sender_user = sender_user.expect("user is authenticated"); + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Ask a remote server if we are not participating in this room + if !services() + .rooms + .state_cache + .server_in_room(services().globals.server_name(), room_id)? + { + let (make_join_response, remote_server) = + make_join_request(sender_user, room_id, servers).await?; + + let room_version_id = match make_join_response.room_version { + Some(room_version) + if services() + .globals + .supported_room_versions() + .contains(&room_version) => + { + room_version + } + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + + let mut join_event_stub: CanonicalJsonObject = + serde_json::from_str(make_join_response.event.get()).map_err(|_| { + Error::BadServerResponse("Invalid make_join event json received from server.") + })?; + + let join_authorized_via_users_server = join_event_stub + .get("content") + .map(|s| { + s.as_object()? + .get("join_authorised_via_users_server")? + .as_str() + }) + .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + + // TODO: Is origin needed? + join_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + join_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + join_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason, + join_authorized_via_users_server, + }) + .expect("event is valid, we just created it"), + ); + + // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms + join_event_stub.remove("event_id"); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut join_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); + + // Generate event id + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&join_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + ); + let event_id = <&EventId>::try_from(event_id.as_str()) + .expect("ruma's reference hashes are valid event ids"); + + // Add event_id back + join_event_stub.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + // It has enough fields to be called a proper event now + let mut join_event = join_event_stub; + + let send_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + }, + ) + .await?; + + if let Some(signed_raw) = &send_join_response.room_state.event { + let (signed_event_id, signed_value) = + match gen_event_id_canonical_json(signed_raw, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + } + }; + + if signed_event_id != event_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent event with wrong event id", + )); + } + + if let Ok(signature) = signed_value["signatures"] + .as_object() + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent invalid signatures type", + )) + .and_then(|e| { + e.get(remote_server.as_str()).ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server did not send its signature", + )) + }) + { + join_event + .get_mut("signatures") + .expect("we created a valid pdu") + .as_object_mut() + .expect("we created a valid pdu") + .insert(remote_server.to_string(), signature.clone()); + } else { + warn!("Server {} sent invalid sendjoin event", remote_server); + } + } + + services().rooms.short.get_or_create_shortroomid(room_id)?; + + let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) + .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; + + let mut state = HashMap::new(); + let pub_key_map = RwLock::new(BTreeMap::new()); + + services() + .rooms + .event_handler + .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) + .await?; + + for result in send_join_response + .room_state + .state + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let (event_id, value) = match result { + Ok(t) => t, + Err(_) => continue, + }; + + let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { + warn!("{:?}: {}", value, e); + Error::BadServerResponse("Invalid PDU in send_join response.") + })?; + + services() + .rooms + .outlier + .add_pdu_outlier(&event_id, &value)?; + if let Some(state_key) = &pdu.state_key { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + state.insert(shortstatekey, pdu.event_id.clone()); + } + } + + for result in send_join_response + .room_state + .auth_chain + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let (event_id, value) = match result { + Ok(t) => t, + Err(_) => continue, + }; + + services() + .rooms + .outlier + .add_pdu_outlier(&event_id, &value)?; + } + + if !state_res::event_auth::auth_check( + &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), + &parsed_join_pdu, + None::<PduEvent>, // TODO: third party invite + |k, s| { + services() + .rooms + .timeline + .get_pdu( + state.get( + &services() + .rooms + .short + .get_or_create_shortstatekey(&k.to_string().into(), s) + .ok()?, + )?, + ) + .ok()? + }, + ) + .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed"))? + { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth check failed", + )); + } + + let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( + room_id, + state + .into_iter() + .map(|(k, id)| { + services() + .rooms + .state_compressor + .compress_state_event(k, &id) + }) + .collect::<Result<_>>()?, + )?; + + services() + .rooms + .state + .force_state(room_id, statehash_before_join, new, removed, &state_lock) + .await?; + + services().rooms.state_cache.update_joined_count(room_id)?; + + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + + services().rooms.timeline.append_pdu( + &parsed_join_pdu, + join_event, + vec![(*parsed_join_pdu.event_id).to_owned()], + &state_lock, + )?; + + // We set the room state after inserting the pdu, so that we never have a moment in time + // where events in the current room state do not exist + services() + .rooms + .state + .set_room_state(room_id, statehash_after_join, &state_lock)?; + } else { + let join_rules_event = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomJoinRules, + "", + )?; + let power_levels_event = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomPowerLevels, + "", + )?; + + let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose()?; + let power_levels_event_content: Option<RoomPowerLevelsEventContent> = power_levels_event + .as_ref() + .map(|power_levels_event| { + serde_json::from_str(power_levels_event.content.get()).map_err(|e| { + warn!("Invalid power levels event: {}", e); + Error::bad_database("Invalid power levels event in db.") + }) + }) + .transpose()?; + + let restriction_rooms = match join_rules_event_content { + Some(RoomJoinRulesEventContent { + join_rule: JoinRule::Restricted(restricted), + }) + | Some(RoomJoinRulesEventContent { + join_rule: JoinRule::KnockRestricted(restricted), + }) => restricted + .allow + .into_iter() + .filter_map(|a| match a { + AllowRule::RoomMembership(r) => Some(r.room_id), + _ => None, + }) + .collect(), + _ => Vec::new(), + }; + + let authorized_user = restriction_rooms + .iter() + .find_map(|restriction_room_id| { + if !services() + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + .ok()? + { + return None; + } + let authorized_user = power_levels_event_content + .as_ref() + .and_then(|c| { + c.users + .iter() + .filter(|(uid, i)| { + uid.server_name() == services().globals.server_name() + && **i > ruma::int!(0) + && services() + .rooms + .state_cache + .is_joined(uid, restriction_room_id) + .unwrap_or(false) + }) + .max_by_key(|(_, i)| *i) + .map(|(u, _)| u.to_owned()) + }) + .or_else(|| { + // TODO: Check here if user is actually allowed to invite. Currently the auth + // check will just fail in this case. + services() + .rooms + .state_cache + .room_members(restriction_room_id) + .filter_map(|r| r.ok()) + .find(|uid| uid.server_name() == services().globals.server_name()) + }); + Some(authorized_user) + }) + .flatten(); + + let event = RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: reason.clone(), + join_authorized_via_users_server: authorized_user, + }; + + // Try normal join first + let error = match services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + ) { + Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), + Err(e) => e, + }; + + if !restriction_rooms.is_empty() { + // We couldn't do the join locally, maybe federation can help to satisfy the restricted + // join requirements + let (make_join_response, remote_server) = + make_join_request(sender_user, room_id, servers).await?; + + let room_version_id = match make_join_response.room_version { + Some(room_version_id) + if services() + .globals + .supported_room_versions() + .contains(&room_version_id) => + { + room_version_id + } + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + let mut join_event_stub: CanonicalJsonObject = + serde_json::from_str(make_join_response.event.get()).map_err(|_| { + Error::BadServerResponse("Invalid make_join event json received from server.") + })?; + let join_authorized_via_users_server = join_event_stub + .get("content") + .map(|s| { + s.as_object()? + .get("join_authorised_via_users_server")? + .as_str() + }) + .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + // TODO: Is origin needed? + join_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + join_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + join_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason, + join_authorized_via_users_server, + }) + .expect("event is valid, we just created it"), + ); + + // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms + join_event_stub.remove("event_id"); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut join_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); + + // Generate event id + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&join_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + ); + let event_id = <&EventId>::try_from(event_id.as_str()) + .expect("ruma's reference hashes are valid event ids"); + + // Add event_id back + join_event_stub.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + // It has enough fields to be called a proper event now + let join_event = join_event_stub; + + let send_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + }, + ) + .await?; + + if let Some(signed_raw) = send_join_response.room_state.event { + let (signed_event_id, signed_value) = + match gen_event_id_canonical_json(&signed_raw, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + } + }; + + if signed_event_id != event_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent event with wrong event id", + )); + } + + drop(state_lock); + let pub_key_map = RwLock::new(BTreeMap::new()); + services() + .rooms + .event_handler + .handle_incoming_pdu( + &remote_server, + &signed_event_id, + room_id, + signed_value, + true, + &pub_key_map, + ) + .await?; + } else { + return Err(error); + } + } else { + return Err(error); + } + } + + Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) +} + +async fn make_join_request( + sender_user: &UserId, + room_id: &RoomId, + servers: &[OwnedServerName], +) -> Result<( + federation::membership::prepare_join_event::v1::Response, + OwnedServerName, +)> { + let mut make_join_response_and_server = Err(Error::BadServerResponse( + "No server available to assist in joining.", + )); + + for remote_server in servers { + if remote_server == services().globals.server_name() { + continue; + } + let make_join_response = services() + .sending + .send_federation_request( + remote_server, + federation::membership::prepare_join_event::v1::Request { + room_id: room_id.to_owned(), + user_id: sender_user.to_owned(), + ver: services().globals.supported_room_versions(), + }, + ) + .await; + + make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); + + if make_join_response_and_server.is_ok() { + break; + } + } + + make_join_response_and_server +} + +fn validate_and_add_event_id( + pdu: &RawJsonValue, + room_version: &RoomVersionId, + pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, +) -> Result<(OwnedEventId, CanonicalJsonObject)> { + let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&value, room_version) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + let back_off = |id| match services() + .globals + .bad_event_ratelimiter + .write() + .unwrap() + .entry(id) + { + Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; + + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(&event_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {}", event_id); + return Err(Error::BadServerResponse("bad event, still backing off")); + } + } + + if let Err(e) = ruma::signatures::verify_event( + &*pub_key_map + .read() + .map_err(|_| Error::bad_database("RwLock is poisoned."))?, + &value, + room_version, + ) { + warn!("Event {} failed verification {:?} {}", event_id, pdu, e); + back_off(event_id); + return Err(Error::BadServerResponse("Event failed verification.")); + } + + value.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + Ok((event_id, value)) +} + +pub(crate) async fn invite_helper<'a>( + sender_user: &UserId, + user_id: &UserId, + room_id: &RoomId, + reason: Option<String>, + is_direct: bool, +) -> Result<()> { + if user_id.server_name() != services().globals.server_name() { + let (pdu, pdu_json, invite_room_state) = { + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + let content = to_raw_value(&RoomMemberEventContent { + avatar_url: None, + displayname: None, + is_direct: Some(is_direct), + membership: MembershipState::Invite, + third_party_invite: None, + blurhash: None, + reason, + join_authorized_via_users_server: None, + }) + .expect("member event is valid value"); + + let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( + PduBuilder { + event_type: RoomEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + )?; + + let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; + + drop(state_lock); + + (pdu, pdu_json, invite_room_state) + }; + + let room_version_id = services().rooms.state.get_room_version(room_id)?; + + let response = services() + .sending + .send_federation_request( + user_id.server_name(), + create_invite::v2::Request { + room_id: room_id.to_owned(), + event_id: (*pdu.event_id).to_owned(), + room_version: room_version_id.clone(), + event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), + invite_room_state, + }, + ) + .await?; + + let pub_key_map = RwLock::new(BTreeMap::new()); + + // We do not add the event_id field to the pdu here because of signature and hashes checks + let (event_id, value) = match gen_event_id_canonical_json(&response.event, &room_version_id) + { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + } + }; + + if *pdu.event_id != *event_id { + warn!("Server {} changed invite event, that's not allowed in the spec: ours: {:?}, theirs: {:?}", user_id.server_name(), pdu_json, value); + } + + let origin: OwnedServerName = serde_json::from_value( + serde_json::to_value(value.get("origin").ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Event needs an origin field.", + ))?) + .expect("CanonicalJson is valid json value"), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; + + let pdu_id: Vec<u8> = services() + .rooms + .event_handler + .handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) + .await? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; + + // Bind to variable because of lifetimes + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(|r| r.ok()) + .filter(|server| &**server != services().globals.server_name()); + + services().sending.send_pdu(servers, &pdu_id)?; + + return Ok(()); + } + + if !services() + .rooms + .state_cache + .is_joined(sender_user, room_id)? + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: services().users.displayname(user_id)?, + avatar_url: services().users.avatar_url(user_id)?, + is_direct: Some(is_direct), + third_party_invite: None, + blurhash: services().users.blurhash(user_id)?, + reason, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + )?; + + drop(state_lock); + + Ok(()) +} + +// Make a user leave all their joined rooms +pub async fn leave_all_rooms(user_id: &UserId) -> Result<()> { + let all_rooms = services() + .rooms + .state_cache + .rooms_joined(user_id) + .chain( + services() + .rooms + .state_cache + .rooms_invited(user_id) + .map(|t| t.map(|(r, _)| r)), + ) + .collect::<Vec<_>>(); + + for room_id in all_rooms { + let room_id = match room_id { + Ok(room_id) => room_id, + Err(_) => continue, + }; + + let _ = leave_room(user_id, &room_id, None).await; + } + + Ok(()) +} + +pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> { + // Ask a remote server if we don't have this room + if !services().rooms.metadata.exists(room_id)? + && room_id.server_name() != services().globals.server_name() + { + if let Err(e) = remote_leave_room(user_id, room_id).await { + warn!("Failed to leave room {} remotely: {}", user_id, e); + // Don't tell the client about this error + } + + let last_state = services() + .rooms + .state_cache + .invite_state(user_id, room_id)? + .map_or_else( + || services().rooms.state_cache.left_state(user_id, room_id), + |s| Ok(Some(s)), + )?; + + // We always drop the invite, we can't rely on other servers + services().rooms.state_cache.update_membership( + room_id, + user_id, + MembershipState::Leave, + user_id, + last_state, + true, + )?; + } else { + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + let member_event = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomMember, + user_id.as_str(), + )?; + + // Fix for broken rooms + let member_event = match member_event { + None => { + error!("Trying to leave a room you are not a member of."); + + services().rooms.state_cache.update_membership( + room_id, + user_id, + MembershipState::Leave, + user_id, + None, + true, + )?; + return Ok(()); + } + Some(e) => e, + }; + + let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; + + event.membership = MembershipState::Leave; + event.reason = reason; + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + user_id, + room_id, + &state_lock, + )?; + } + + Ok(()) +} + +async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut make_leave_response_and_server = Err(Error::BadServerResponse( + "No server available to assist in leaving.", + )); + + let invite_state = services() + .rooms + .state_cache + .invite_state(user_id, room_id)? + .ok_or(Error::BadRequest( + ErrorKind::BadState, + "User is not invited.", + ))?; + + let servers: HashSet<_> = invite_state + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect(); + + for remote_server in servers { + let make_leave_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::prepare_leave_event::v1::Request { + room_id: room_id.to_owned(), + user_id: user_id.to_owned(), + }, + ) + .await; + + make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); + + if make_leave_response_and_server.is_ok() { + break; + } + } + + let (make_leave_response, remote_server) = make_leave_response_and_server?; + + let room_version_id = match make_leave_response.room_version { + Some(version) + if services() + .globals + .supported_room_versions() + .contains(&version) => + { + version + } + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + + let mut leave_event_stub = serde_json::from_str::<CanonicalJsonObject>( + make_leave_response.event.get(), + ) + .map_err(|_| Error::BadServerResponse("Invalid make_leave event json received from server."))?; + + // TODO: Is origin needed? + leave_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + leave_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms + leave_event_stub.remove("event_id"); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut leave_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); + + // Generate event id + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + // Add event_id back + leave_event_stub.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + // It has enough fields to be called a proper event now + let leave_event = leave_event_stub; + + services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_leave_event::v2::Request { + room_id: room_id.to_owned(), + event_id, + pdu: PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), + }, + ) + .await?; + + Ok(()) +} diff --git a/src/client_server/message.rs b/src/api/client_server/message.rs index 1348132..6ad0751 100644 --- a/src/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma}; +use crate::{service::pdu::PduBuilder, services, utils, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -19,14 +19,14 @@ use std::{ /// - The only requirement for the content is that it has to be valid json /// - Tries to send the event into the room, auth rules will determine if it is allowed pub async fn send_message_event_route( - db: DatabaseGuard, - body: Ruma<send_message_event::v3::IncomingRequest>, + body: Ruma<send_message_event::v3::Request>, ) -> Result<send_message_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -37,7 +37,7 @@ pub async fn send_message_event_route( // Forbid m.room.encrypted if encryption is disabled if RoomEventType::RoomEncrypted == body.event_type.to_string().into() - && !db.globals.allow_encryption() + && !services().globals.allow_encryption() { return Err(Error::BadRequest( ErrorKind::Forbidden, @@ -47,7 +47,8 @@ pub async fn send_message_event_route( // Check if this is a new transaction id if let Some(response) = - db.transaction_ids + services() + .transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? { // The client might have sent a txnid of the /sendToDevice endpoint @@ -69,7 +70,7 @@ pub async fn send_message_event_route( let mut unsigned = BTreeMap::new(); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: body.event_type.to_string().into(), content: serde_json::from_str(body.body.body.json().get()) @@ -80,11 +81,10 @@ pub async fn send_message_event_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; - db.transaction_ids.add_txnid( + services().transaction_ids.add_txnid( sender_user, sender_device, &body.txn_id, @@ -93,8 +93,6 @@ pub async fn send_message_event_route( drop(state_lock); - db.flush()?; - Ok(send_message_event::v3::Response::new( (*event_id).to_owned(), )) @@ -107,13 +105,16 @@ pub async fn send_message_event_route( /// - Only works if the user is joined (TODO: always allow, but only show events where the user was /// joined, depending on history_visibility) pub async fn get_message_events_route( - db: DatabaseGuard, - body: Ruma<get_message_events::v3::IncomingRequest>, + body: Ruma<get_message_events::v3::Request>, ) -> Result<get_message_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -126,15 +127,19 @@ pub async fn get_message_events_route( .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from` value."))?, None => match body.dir { - get_message_events::v3::Direction::Forward => 0, - get_message_events::v3::Direction::Backward => u64::MAX, + ruma::api::client::Direction::Forward => 0, + ruma::api::client::Direction::Backward => u64::MAX, }, }; let to = body.to.as_ref().map(|t| t.parse()); - db.rooms - .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?; + services().rooms.lazy_loading.lazy_load_confirm_delivery( + sender_user, + sender_device, + &body.room_id, + from, + )?; // Use limit or else 10 let limit = body.limit.try_into().map_or(10_usize, |l: u32| l as usize); @@ -146,14 +151,17 @@ pub async fn get_message_events_route( let mut lazy_loaded = HashSet::new(); match body.dir { - get_message_events::v3::Direction::Forward => { - let events_after: Vec<_> = db + ruma::api::client::Direction::Forward => { + let events_after: Vec<_> = services() .rooms + .timeline .pdus_after(sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { - db.rooms + services() + .rooms + .timeline .pdu_count(&pdu_id) .map(|pdu_count| (pdu_count, pdu)) .ok() @@ -162,7 +170,10 @@ pub async fn get_message_events_route( .collect(); for (_, event) in &events_after { - if !db.rooms.lazy_load_was_sent_before( + /* TODO: Remove this when these are resolved: + * https://github.com/vector-im/element-android/issues/3417 + * https://github.com/vector-im/element-web/issues/21034 + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &body.room_id, @@ -170,6 +181,8 @@ pub async fn get_message_events_route( )? { lazy_loaded.insert(event.sender.clone()); } + */ + lazy_loaded.insert(event.sender.clone()); } next_token = events_after.last().map(|(count, _)| count).copied(); @@ -183,14 +196,17 @@ pub async fn get_message_events_route( resp.end = next_token.map(|count| count.to_string()); resp.chunk = events_after; } - get_message_events::v3::Direction::Backward => { - let events_before: Vec<_> = db + ruma::api::client::Direction::Backward => { + let events_before: Vec<_> = services() .rooms + .timeline .pdus_until(sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { - db.rooms + services() + .rooms + .timeline .pdu_count(&pdu_id) .map(|pdu_count| (pdu_count, pdu)) .ok() @@ -199,7 +215,10 @@ pub async fn get_message_events_route( .collect(); for (_, event) in &events_before { - if !db.rooms.lazy_load_was_sent_before( + /* TODO: Remove this when these are resolved: + * https://github.com/vector-im/element-android/issues/3417 + * https://github.com/vector-im/element-web/issues/21034 + if !services().rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &body.room_id, @@ -207,6 +226,8 @@ pub async fn get_message_events_route( )? { lazy_loaded.insert(event.sender.clone()); } + */ + lazy_loaded.insert(event.sender.clone()); } next_token = events_before.last().map(|(count, _)| count).copied(); @@ -224,16 +245,19 @@ pub async fn get_message_events_route( resp.state = Vec::new(); for ll_id in &lazy_loaded { - if let Some(member_event) = - db.rooms - .room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())? - { + if let Some(member_event) = services().rooms.state_accessor.room_state_get( + &body.room_id, + &StateEventType::RoomMember, + ll_id.as_str(), + )? { resp.state.push(member_event.to_state_event()); } } + // TODO: enable again when we are sure clients can handle it + /* if let Some(next_token) = next_token { - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_loading.lazy_load_mark_sent( sender_user, sender_device, &body.room_id, @@ -241,6 +265,7 @@ pub async fn get_message_events_route( next_token, ); } + */ Ok(resp) } diff --git a/src/client_server/mod.rs b/src/api/client_server/mod.rs index 65b7a10..6ed17e7 100644 --- a/src/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -63,6 +63,6 @@ pub use user_directory::*; pub use voip::*; pub const DEVICE_ID_LENGTH: usize = 10; -pub const TOKEN_LENGTH: usize = 256; -pub const SESSION_ID_LENGTH: usize = 256; +pub const TOKEN_LENGTH: usize = 32; +pub const SESSION_ID_LENGTH: usize = 32; pub const AUTO_GEN_PASSWORD_LENGTH: usize = 15; diff --git a/src/client_server/presence.rs b/src/api/client_server/presence.rs index 773fef4..ef88d1a 100644 --- a/src/client_server/presence.rs +++ b/src/api/client_server/presence.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils, Result, Ruma}; +use crate::{services, utils, Result, Ruma}; use ruma::api::client::presence::{get_presence, set_presence}; use std::time::Duration; @@ -6,22 +6,21 @@ use std::time::Duration; /// /// Sets the presence state of the sender user. pub async fn set_presence_route( - db: DatabaseGuard, - body: Ruma<set_presence::v3::IncomingRequest>, + body: Ruma<set_presence::v3::Request>, ) -> Result<set_presence::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for room_id in db.rooms.rooms_joined(sender_user) { + for room_id in services().rooms.state_cache.rooms_joined(sender_user) { let room_id = room_id?; - db.rooms.edus.update_presence( + services().rooms.edus.presence.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -32,12 +31,9 @@ pub async fn set_presence_route( }, sender: sender_user.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_presence::v3::Response {}) } @@ -47,22 +43,23 @@ pub async fn set_presence_route( /// /// - Only works if you share a room with the user pub async fn get_presence_route( - db: DatabaseGuard, - body: Ruma<get_presence::v3::IncomingRequest>, + body: Ruma<get_presence::v3::Request>, ) -> Result<get_presence::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut presence_event = None; - for room_id in db + for room_id in services() .rooms + .user .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? { let room_id = room_id?; - if let Some(presence) = db + if let Some(presence) = services() .rooms .edus + .presence .get_last_presence_event(sender_user, &room_id)? { presence_event = Some(presence); diff --git a/src/client_server/profile.rs b/src/api/client_server/profile.rs index acea19f..6400e89 100644 --- a/src/client_server/profile.rs +++ b/src/api/client_server/profile.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma}; +use crate::{service::pdu::PduBuilder, services, utils, Error, Result, Ruma}; use ruma::{ api::{ client::{ @@ -20,17 +20,18 @@ use std::sync::Arc; /// /// - Also makes sure other users receive the update using presence EDUs pub async fn set_displayname_route( - db: DatabaseGuard, - body: Ruma<set_display_name::v3::IncomingRequest>, + body: Ruma<set_display_name::v3::Request>, ) -> Result<set_display_name::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.users + services() + .users .set_displayname(sender_user, body.displayname.clone())?; // Send a new membership event and presence update into all joined rooms - let all_rooms_joined: Vec<_> = db + let all_rooms_joined: Vec<_> = services() .rooms + .state_cache .rooms_joined(sender_user) .filter_map(|r| r.ok()) .map(|room_id| { @@ -40,7 +41,9 @@ pub async fn set_displayname_route( content: to_raw_value(&RoomMemberEventContent { displayname: body.displayname.clone(), ..serde_json::from_str( - db.rooms + services() + .rooms + .state_accessor .room_state_get( &room_id, &StateEventType::RoomMember, @@ -70,7 +73,8 @@ pub async fn set_displayname_route( for (pdu_builder, room_id) in all_rooms_joined { let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -79,19 +83,22 @@ pub async fn set_displayname_route( ); let state_lock = mutex_state.lock().await; - let _ = db - .rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock); + let _ = services().rooms.timeline.build_and_append_pdu( + pdu_builder, + sender_user, + &room_id, + &state_lock, + ); // Presence update - db.rooms.edus.update_presence( + services().rooms.edus.presence.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -102,12 +109,9 @@ pub async fn set_displayname_route( }, sender: sender_user.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_display_name::v3::Response {}) } @@ -117,18 +121,16 @@ pub async fn set_displayname_route( /// /// - If user is on another server: Fetches displayname over federation pub async fn get_displayname_route( - db: DatabaseGuard, - body: Ruma<get_display_name::v3::IncomingRequest>, + body: Ruma<get_display_name::v3::Request>, ) -> Result<get_display_name::v3::Response> { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { - user_id: &body.user_id, - field: Some(&ProfileField::DisplayName), + user_id: body.user_id.clone(), + field: Some(ProfileField::DisplayName), }, ) .await?; @@ -139,7 +141,7 @@ pub async fn get_displayname_route( } Ok(get_display_name::v3::Response { - displayname: db.users.displayname(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, }) } @@ -149,19 +151,22 @@ pub async fn get_displayname_route( /// /// - Also makes sure other users receive the update using presence EDUs pub async fn set_avatar_url_route( - db: DatabaseGuard, - body: Ruma<set_avatar_url::v3::IncomingRequest>, + body: Ruma<set_avatar_url::v3::Request>, ) -> Result<set_avatar_url::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.users + services() + .users .set_avatar_url(sender_user, body.avatar_url.clone())?; - db.users.set_blurhash(sender_user, body.blurhash.clone())?; + services() + .users + .set_blurhash(sender_user, body.blurhash.clone())?; // Send a new membership event and presence update into all joined rooms - let all_joined_rooms: Vec<_> = db + let all_joined_rooms: Vec<_> = services() .rooms + .state_cache .rooms_joined(sender_user) .filter_map(|r| r.ok()) .map(|room_id| { @@ -171,7 +176,9 @@ pub async fn set_avatar_url_route( content: to_raw_value(&RoomMemberEventContent { avatar_url: body.avatar_url.clone(), ..serde_json::from_str( - db.rooms + services() + .rooms + .state_accessor .room_state_get( &room_id, &StateEventType::RoomMember, @@ -201,7 +208,8 @@ pub async fn set_avatar_url_route( for (pdu_builder, room_id) in all_joined_rooms { let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -210,19 +218,22 @@ pub async fn set_avatar_url_route( ); let state_lock = mutex_state.lock().await; - let _ = db - .rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock); + let _ = services().rooms.timeline.build_and_append_pdu( + pdu_builder, + sender_user, + &room_id, + &state_lock, + ); // Presence update - db.rooms.edus.update_presence( + services().rooms.edus.presence.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -233,12 +244,9 @@ pub async fn set_avatar_url_route( }, sender: sender_user.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_avatar_url::v3::Response {}) } @@ -248,18 +256,16 @@ pub async fn set_avatar_url_route( /// /// - If user is on another server: Fetches avatar_url and blurhash over federation pub async fn get_avatar_url_route( - db: DatabaseGuard, - body: Ruma<get_avatar_url::v3::IncomingRequest>, + body: Ruma<get_avatar_url::v3::Request>, ) -> Result<get_avatar_url::v3::Response> { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { - user_id: &body.user_id, - field: Some(&ProfileField::AvatarUrl), + user_id: body.user_id.clone(), + field: Some(ProfileField::AvatarUrl), }, ) .await?; @@ -271,8 +277,8 @@ pub async fn get_avatar_url_route( } Ok(get_avatar_url::v3::Response { - avatar_url: db.users.avatar_url(&body.user_id)?, - blurhash: db.users.blurhash(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, }) } @@ -282,17 +288,15 @@ pub async fn get_avatar_url_route( /// /// - If user is on another server: Fetches profile over federation pub async fn get_profile_route( - db: DatabaseGuard, - body: Ruma<get_profile::v3::IncomingRequest>, + body: Ruma<get_profile::v3::Request>, ) -> Result<get_profile::v3::Response> { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { - user_id: &body.user_id, + user_id: body.user_id.clone(), field: None, }, ) @@ -305,7 +309,7 @@ pub async fn get_profile_route( }); } - if !db.users.exists(&body.user_id)? { + if !services().users.exists(&body.user_id)? { // Return 404 if this user doesn't exist return Err(Error::BadRequest( ErrorKind::NotFound, @@ -314,8 +318,8 @@ pub async fn get_profile_route( } Ok(get_profile::v3::Response { - avatar_url: db.users.avatar_url(&body.user_id)?, - blurhash: db.users.blurhash(&body.user_id)?, - displayname: db.users.displayname(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, }) } diff --git a/src/client_server/push.rs b/src/api/client_server/push.rs index dc45ea0..b044138 100644 --- a/src/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -1,27 +1,26 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, push::{ delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions, - set_pushrule_enabled, RuleKind, + set_pushrule_enabled, RuleKind, RuleScope, }, }, events::{push_rules::PushRulesEvent, GlobalAccountDataEventType}, - push::{ConditionalPushRuleInit, PatternedPushRuleInit, SimplePushRuleInit}, + push::{ConditionalPushRuleInit, NewPushRule, PatternedPushRuleInit, SimplePushRuleInit}, }; /// # `GET /_matrix/client/r0/pushrules` /// /// Retrieves the push rules event for this user. pub async fn get_pushrules_all_route( - db: DatabaseGuard, body: Ruma<get_pushrules_all::v3::Request>, ) -> Result<get_pushrules_all::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: PushRulesEvent = db + let event = services() .account_data .get( None, @@ -33,8 +32,12 @@ pub async fn get_pushrules_all_route( "PushRules event not found.", ))?; + let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; + Ok(get_pushrules_all::v3::Response { - global: event.content.global, + global: account_data.global, }) } @@ -42,12 +45,11 @@ pub async fn get_pushrules_all_route( /// /// Retrieves a single specified push rule for this user. pub async fn get_pushrule_route( - db: DatabaseGuard, - body: Ruma<get_pushrule::v3::IncomingRequest>, + body: Ruma<get_pushrule::v3::Request>, ) -> Result<get_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: PushRulesEvent = db + let event = services() .account_data .get( None, @@ -59,7 +61,11 @@ pub async fn get_pushrule_route( "PushRules event not found.", ))?; - let global = event.content.global; + let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; + + let global = account_data.global; let rule = match body.kind { RuleKind::Override => global .override_ @@ -98,20 +104,19 @@ pub async fn get_pushrule_route( /// /// Creates a single specified push rule for this user. pub async fn set_pushrule_route( - db: DatabaseGuard, - body: Ruma<set_pushrule::v3::IncomingRequest>, + body: Ruma<set_pushrule::v3::Request>, ) -> Result<set_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; - if body.scope != "global" { + if body.scope != RuleScope::Global { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Scopes other than 'global' are not supported.", )); } - let mut event: PushRulesEvent = db + let event = services() .account_data .get( None, @@ -123,79 +128,78 @@ pub async fn set_pushrule_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; - match body.kind { - RuleKind::Override => { + let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = &mut account_data.content.global; + match body.rule { + NewPushRule::Override(rule) => { global.override_.replace( ConditionalPushRuleInit { - actions: body.actions, + actions: rule.actions, default: false, enabled: true, - rule_id: body.rule_id, - conditions: body.conditions, + rule_id: rule.rule_id, + conditions: rule.conditions, } .into(), ); } - RuleKind::Underride => { + NewPushRule::Underride(rule) => { global.underride.replace( ConditionalPushRuleInit { - actions: body.actions, + actions: rule.actions, default: false, enabled: true, - rule_id: body.rule_id, - conditions: body.conditions, + rule_id: rule.rule_id, + conditions: rule.conditions, } .into(), ); } - RuleKind::Sender => { + NewPushRule::Sender(rule) => { global.sender.replace( SimplePushRuleInit { - actions: body.actions, + actions: rule.actions, default: false, enabled: true, - rule_id: body.rule_id, + rule_id: rule.rule_id, } .into(), ); } - RuleKind::Room => { + NewPushRule::Room(rule) => { global.room.replace( SimplePushRuleInit { - actions: body.actions, + actions: rule.actions, default: false, enabled: true, - rule_id: body.rule_id, + rule_id: rule.rule_id, } .into(), ); } - RuleKind::Content => { + NewPushRule::Content(rule) => { global.content.replace( PatternedPushRuleInit { - actions: body.actions, + actions: rule.actions, default: false, enabled: true, - rule_id: body.rule_id, - pattern: body.pattern.unwrap_or_default(), + rule_id: rule.rule_id, + pattern: rule.pattern, } .into(), ); } - _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &event, - &db.globals, + &serde_json::to_value(account_data).expect("to json value always works"), )?; - db.flush()?; - Ok(set_pushrule::v3::Response {}) } @@ -203,19 +207,18 @@ pub async fn set_pushrule_route( /// /// Gets the actions of a single specified push rule for this user. pub async fn get_pushrule_actions_route( - db: DatabaseGuard, - body: Ruma<get_pushrule_actions::v3::IncomingRequest>, + body: Ruma<get_pushrule_actions::v3::Request>, ) -> Result<get_pushrule_actions::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != "global" { + if body.scope != RuleScope::Global { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Scopes other than 'global' are not supported.", )); } - let mut event: PushRulesEvent = db + let event = services() .account_data .get( None, @@ -227,7 +230,11 @@ pub async fn get_pushrule_actions_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; + + let global = account_data.global; let actions = match body.kind { RuleKind::Override => global .override_ @@ -252,8 +259,6 @@ pub async fn get_pushrule_actions_route( _ => None, }; - db.flush()?; - Ok(get_pushrule_actions::v3::Response { actions: actions.unwrap_or_default(), }) @@ -263,19 +268,18 @@ pub async fn get_pushrule_actions_route( /// /// Sets the actions of a single specified push rule for this user. pub async fn set_pushrule_actions_route( - db: DatabaseGuard, - body: Ruma<set_pushrule_actions::v3::IncomingRequest>, + body: Ruma<set_pushrule_actions::v3::Request>, ) -> Result<set_pushrule_actions::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != "global" { + if body.scope != RuleScope::Global { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Scopes other than 'global' are not supported.", )); } - let mut event: PushRulesEvent = db + let event = services() .account_data .get( None, @@ -287,7 +291,10 @@ pub async fn set_pushrule_actions_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = &mut account_data.content.global; match body.kind { RuleKind::Override => { if let Some(mut rule) = global.override_.get(body.rule_id.as_str()).cloned() { @@ -322,16 +329,13 @@ pub async fn set_pushrule_actions_route( _ => {} }; - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &event, - &db.globals, + &serde_json::to_value(account_data).expect("to json value always works"), )?; - db.flush()?; - Ok(set_pushrule_actions::v3::Response {}) } @@ -339,19 +343,18 @@ pub async fn set_pushrule_actions_route( /// /// Gets the enabled status of a single specified push rule for this user. pub async fn get_pushrule_enabled_route( - db: DatabaseGuard, - body: Ruma<get_pushrule_enabled::v3::IncomingRequest>, + body: Ruma<get_pushrule_enabled::v3::Request>, ) -> Result<get_pushrule_enabled::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != "global" { + if body.scope != RuleScope::Global { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Scopes other than 'global' are not supported.", )); } - let mut event: PushRulesEvent = db + let event = services() .account_data .get( None, @@ -363,7 +366,10 @@ pub async fn get_pushrule_enabled_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let account_data = serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = account_data.content.global; let enabled = match body.kind { RuleKind::Override => global .override_ @@ -393,8 +399,6 @@ pub async fn get_pushrule_enabled_route( _ => false, }; - db.flush()?; - Ok(get_pushrule_enabled::v3::Response { enabled }) } @@ -402,19 +406,18 @@ pub async fn get_pushrule_enabled_route( /// /// Sets the enabled status of a single specified push rule for this user. pub async fn set_pushrule_enabled_route( - db: DatabaseGuard, - body: Ruma<set_pushrule_enabled::v3::IncomingRequest>, + body: Ruma<set_pushrule_enabled::v3::Request>, ) -> Result<set_pushrule_enabled::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != "global" { + if body.scope != RuleScope::Global { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Scopes other than 'global' are not supported.", )); } - let mut event: PushRulesEvent = db + let event = services() .account_data .get( None, @@ -426,7 +429,10 @@ pub async fn set_pushrule_enabled_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = &mut account_data.content.global; match body.kind { RuleKind::Override => { if let Some(mut rule) = global.override_.get(body.rule_id.as_str()).cloned() { @@ -466,16 +472,13 @@ pub async fn set_pushrule_enabled_route( _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &event, - &db.globals, + &serde_json::to_value(account_data).expect("to json value always works"), )?; - db.flush()?; - Ok(set_pushrule_enabled::v3::Response {}) } @@ -483,19 +486,18 @@ pub async fn set_pushrule_enabled_route( /// /// Deletes a single specified push rule for this user. pub async fn delete_pushrule_route( - db: DatabaseGuard, - body: Ruma<delete_pushrule::v3::IncomingRequest>, + body: Ruma<delete_pushrule::v3::Request>, ) -> Result<delete_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != "global" { + if body.scope != RuleScope::Global { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Scopes other than 'global' are not supported.", )); } - let mut event: PushRulesEvent = db + let event = services() .account_data .get( None, @@ -507,7 +509,10 @@ pub async fn delete_pushrule_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = &mut account_data.content.global; match body.kind { RuleKind::Override => { if let Some(rule) = global.override_.get(body.rule_id.as_str()).cloned() { @@ -537,16 +542,13 @@ pub async fn delete_pushrule_route( _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &event, - &db.globals, + &serde_json::to_value(account_data).expect("to json value always works"), )?; - db.flush()?; - Ok(delete_pushrule::v3::Response {}) } @@ -554,13 +556,12 @@ pub async fn delete_pushrule_route( /// /// Gets all currently active pushers for the sender user. pub async fn get_pushers_route( - db: DatabaseGuard, body: Ruma<get_pushers::v3::Request>, ) -> Result<get_pushers::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_pushers::v3::Response { - pushers: db.pusher.get_pushers(sender_user)?, + pushers: services().pusher.get_pushers(sender_user)?, }) } @@ -570,15 +571,13 @@ pub async fn get_pushers_route( /// /// - TODO: Handle `append` pub async fn set_pushers_route( - db: DatabaseGuard, body: Ruma<set_pusher::v3::Request>, ) -> Result<set_pusher::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let pusher = body.pusher.clone(); - - db.pusher.set_pusher(sender_user, pusher)?; - db.flush()?; + services() + .pusher + .set_pusher(sender_user, body.action.clone())?; Ok(set_pusher::v3::Response::default()) } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs new file mode 100644 index 0000000..b12468a --- /dev/null +++ b/src/api/client_server/read_marker.rs @@ -0,0 +1,162 @@ +use crate::{services, Error, Result, Ruma}; +use ruma::{ + api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, + events::{ + receipt::{ReceiptThread, ReceiptType}, + RoomAccountDataEventType, + }, + MilliSecondsSinceUnixEpoch, +}; +use std::collections::BTreeMap; + +/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` +/// +/// Sets different types of read markers. +/// +/// - Updates fully-read account data event to `fully_read` +/// - If `read_receipt` is set: Update private marker and public read receipt EDU +pub async fn set_read_marker_route( + body: Ruma<set_read_marker::v3::Request>, +) -> Result<set_read_marker::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + if let Some(fully_read) = &body.fully_read { + let fully_read_event = ruma::events::fully_read::FullyReadEvent { + content: ruma::events::fully_read::FullyReadEventContent { + event_id: fully_read.clone(), + }, + }; + services().account_data.update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + )?; + } + + if body.private_read_receipt.is_some() || body.read_receipt.is_some() { + services() + .rooms + .user + .reset_notification_counts(sender_user, &body.room_id)?; + } + + if let Some(event) = &body.private_read_receipt { + services().rooms.edus.read_receipt.private_read_set( + &body.room_id, + sender_user, + services() + .rooms + .timeline + .get_pdu_count(event)? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Event does not exist.", + ))?, + )?; + } + + if let Some(event) = &body.read_receipt { + let mut user_receipts = BTreeMap::new(); + user_receipts.insert( + sender_user.clone(), + ruma::events::receipt::Receipt { + ts: Some(MilliSecondsSinceUnixEpoch::now()), + thread: ReceiptThread::Unthreaded, + }, + ); + + let mut receipts = BTreeMap::new(); + receipts.insert(ReceiptType::Read, user_receipts); + + let mut receipt_content = BTreeMap::new(); + receipt_content.insert(event.to_owned(), receipts); + + services().rooms.edus.read_receipt.readreceipt_update( + sender_user, + &body.room_id, + ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + )?; + } + + Ok(set_read_marker::v3::Response {}) +} + +/// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}` +/// +/// Sets private read marker and public read receipt EDU. +pub async fn create_receipt_route( + body: Ruma<create_receipt::v3::Request>, +) -> Result<create_receipt::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + if matches!( + &body.receipt_type, + create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate + ) { + services() + .rooms + .user + .reset_notification_counts(sender_user, &body.room_id)?; + } + + match body.receipt_type { + create_receipt::v3::ReceiptType::FullyRead => { + let fully_read_event = ruma::events::fully_read::FullyReadEvent { + content: ruma::events::fully_read::FullyReadEventContent { + event_id: body.event_id.clone(), + }, + }; + services().account_data.update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + )?; + } + create_receipt::v3::ReceiptType::Read => { + let mut user_receipts = BTreeMap::new(); + user_receipts.insert( + sender_user.clone(), + ruma::events::receipt::Receipt { + ts: Some(MilliSecondsSinceUnixEpoch::now()), + thread: ReceiptThread::Unthreaded, + }, + ); + let mut receipts = BTreeMap::new(); + receipts.insert(ReceiptType::Read, user_receipts); + + let mut receipt_content = BTreeMap::new(); + receipt_content.insert(body.event_id.to_owned(), receipts); + + services().rooms.edus.read_receipt.readreceipt_update( + sender_user, + &body.room_id, + ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + )?; + } + create_receipt::v3::ReceiptType::ReadPrivate => { + services().rooms.edus.read_receipt.private_read_set( + &body.room_id, + sender_user, + services() + .rooms + .timeline + .get_pdu_count(&body.event_id)? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Event does not exist.", + ))?, + )?; + } + _ => return Err(Error::bad_database("Unsupported receipt type")), + } + + Ok(create_receipt::v3::Response {}) +} diff --git a/src/client_server/redact.rs b/src/api/client_server/redact.rs index 059e0f5..a29a561 100644 --- a/src/client_server/redact.rs +++ b/src/api/client_server/redact.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::{database::DatabaseGuard, pdu::PduBuilder, Result, Ruma}; +use crate::{service::pdu::PduBuilder, services, Result, Ruma}; use ruma::{ api::client::redact::redact_event, events::{room::redaction::RoomRedactionEventContent, RoomEventType}, @@ -14,14 +14,14 @@ use serde_json::value::to_raw_value; /// /// - TODO: Handle txn id pub async fn redact_event_route( - db: DatabaseGuard, - body: Ruma<redact_event::v3::IncomingRequest>, + body: Ruma<redact_event::v3::Request>, ) -> Result<redact_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -30,7 +30,7 @@ pub async fn redact_event_route( ); let state_lock = mutex_state.lock().await; - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomRedaction, content: to_raw_value(&RoomRedactionEventContent { @@ -43,14 +43,11 @@ pub async fn redact_event_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(redact_event::v3::Response { event_id }) } diff --git a/src/client_server/report.rs b/src/api/client_server/report.rs index 14768e1..ab5027c 100644 --- a/src/client_server/report.rs +++ b/src/api/client_server/report.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils::HtmlEscape, Error, Result, Ruma}; +use crate::{services, utils::HtmlEscape, Error, Result, Ruma}; use ruma::{ api::client::{error::ErrorKind, room::report_content}, events::room::message, @@ -10,12 +10,11 @@ use ruma::{ /// Reports an inappropriate event to homeserver admins /// pub async fn report_event_route( - db: DatabaseGuard, - body: Ruma<report_content::v3::IncomingRequest>, + body: Ruma<report_content::v3::Request>, ) -> Result<report_content::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let pdu = match db.rooms.get_pdu(&body.event_id)? { + let pdu = match services().rooms.timeline.get_pdu(&body.event_id)? { Some(pdu) => pdu, _ => { return Err(Error::BadRequest( @@ -39,7 +38,7 @@ pub async fn report_event_route( )); }; - db.admin + services().admin .send_message(message::RoomMessageEventContent::text_html( format!( "Report received from: {}\n\n\ @@ -66,7 +65,5 @@ pub async fn report_event_route( ), )); - db.flush()?; - Ok(report_content::v3::Response {}) } diff --git a/src/client_server/room.rs b/src/api/client_server/room.rs index a5b7970..830e085 100644 --- a/src/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -1,5 +1,5 @@ use crate::{ - client_server::invite_helper, database::DatabaseGuard, pdu::PduBuilder, Error, Result, Ruma, + api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma, }; use ruma::{ api::client::{ @@ -22,8 +22,8 @@ use ruma::{ RoomEventType, StateEventType, }, int, - serde::{CanonicalJsonObject, JsonObject}, - RoomAliasId, RoomId, + serde::JsonObject, + CanonicalJsonObject, OwnedRoomAliasId, RoomAliasId, RoomId, }; use serde_json::{json, value::to_raw_value}; use std::{cmp::max, collections::BTreeMap, sync::Arc}; @@ -46,19 +46,19 @@ use tracing::{info, warn}; /// - Send events implied by `name` and `topic` /// - Send invite events pub async fn create_room_route( - db: DatabaseGuard, - body: Ruma<create_room::v3::IncomingRequest>, + body: Ruma<create_room::v3::Request>, ) -> Result<create_room::v3::Response> { use create_room::v3::RoomPreset; let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let room_id = RoomId::new(db.globals.server_name()); + let room_id = RoomId::new(services().globals.server_name()); - db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + services().rooms.short.get_or_create_shortroomid(&room_id)?; let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -67,9 +67,9 @@ pub async fn create_room_route( ); let state_lock = mutex_state.lock().await; - if !db.globals.allow_room_creation() + if !services().globals.allow_room_creation() && !body.from_appservice - && !db.users.is_admin(sender_user, &db.rooms, &db.globals)? + && !services().users.is_admin(sender_user)? { return Err(Error::BadRequest( ErrorKind::Forbidden, @@ -77,18 +77,24 @@ pub async fn create_room_route( )); } - let alias: Option<Box<RoomAliasId>> = + let alias: Option<OwnedRoomAliasId> = body.room_alias_name .as_ref() .map_or(Ok(None), |localpart| { // TODO: Check for invalid characters and maximum length - let alias = - RoomAliasId::parse(format!("#{}:{}", localpart, db.globals.server_name())) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.") - })?; - - if db.rooms.id_from_alias(&alias)?.is_some() { + let alias = RoomAliasId::parse(format!( + "#{}:{}", + localpart, + services().globals.server_name() + )) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; + + if services() + .rooms + .alias + .resolve_local_alias(&alias)? + .is_some() + { Err(Error::BadRequest( ErrorKind::RoomInUse, "Room alias already exists.", @@ -100,7 +106,11 @@ pub async fn create_room_route( let room_version = match body.room_version.clone() { Some(room_version) => { - if db.rooms.is_supported_version(&db, &room_version) { + if services() + .globals + .supported_room_versions() + .contains(&room_version) + { room_version } else { return Err(Error::BadRequest( @@ -109,7 +119,7 @@ pub async fn create_room_route( )); } } - None => db.globals.default_room_version(), + None => services().globals.default_room_version(), }; let content = match &body.creation_content { @@ -163,7 +173,7 @@ pub async fn create_room_route( } // 1. The room create event - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCreate, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -173,21 +183,20 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 2. Let the room creator join - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: Some(body.is_direct), third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -198,21 +207,17 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 3. Power levels // Figure out preset. We need it for preset specific events - let preset = body - .preset - .clone() - .unwrap_or_else(|| match &body.visibility { - room::Visibility::Private => RoomPreset::PrivateChat, - room::Visibility::Public => RoomPreset::PublicChat, - _ => RoomPreset::PrivateChat, // Room visibility should not be custom - }); + let preset = body.preset.clone().unwrap_or(match &body.visibility { + room::Visibility::Private => RoomPreset::PrivateChat, + room::Visibility::Public => RoomPreset::PublicChat, + _ => RoomPreset::PrivateChat, // Room visibility should not be custom + }); let mut users = BTreeMap::new(); users.insert(sender_user.clone(), int!(100)); @@ -240,7 +245,7 @@ pub async fn create_room_route( } } - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomPowerLevels, content: to_raw_value(&power_levels_content) @@ -251,13 +256,12 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 4. Canonical room alias if let Some(room_alias_id) = &alias { - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCanonicalAlias, content: to_raw_value(&RoomCanonicalAliasEventContent { @@ -271,7 +275,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } @@ -279,7 +282,7 @@ pub async fn create_room_route( // 5. Events set by preset // 5.1 Join Rules - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomJoinRules, content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { @@ -294,12 +297,11 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 5.2 History Visibility - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomHistoryVisibility, content: to_raw_value(&RoomHistoryVisibilityEventContent::new( @@ -312,12 +314,11 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 5.3 Guest Access - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomGuestAccess, content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { @@ -331,7 +332,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; @@ -346,18 +346,23 @@ pub async fn create_room_route( pdu_builder.state_key.get_or_insert_with(|| "".to_owned()); // Silently skip encryption events if they are not allowed - if pdu_builder.event_type == RoomEventType::RoomEncryption && !db.globals.allow_encryption() + if pdu_builder.event_type == RoomEventType::RoomEncryption + && !services().globals.allow_encryption() { continue; } - db.rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock)?; + services().rooms.timeline.build_and_append_pdu( + pdu_builder, + sender_user, + &room_id, + &state_lock, + )?; } // 7. Events implied by name and topic if let Some(name) = &body.name { - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomName, content: to_raw_value(&RoomNameEventContent::new(Some(name.clone()))) @@ -368,13 +373,12 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } if let Some(topic) = &body.topic { - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomTopic, content: to_raw_value(&RoomTopicEventContent { @@ -387,7 +391,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } @@ -395,22 +398,20 @@ pub async fn create_room_route( // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - let _ = invite_helper(sender_user, user_id, &room_id, &db, body.is_direct).await; + let _ = invite_helper(sender_user, user_id, &room_id, None, body.is_direct).await; } // Homeserver specific stuff if let Some(alias) = alias { - db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; + services().rooms.alias.set_alias(&alias, &room_id)?; } if body.visibility == room::Visibility::Public { - db.rooms.set_public(&room_id, true)?; + services().rooms.directory.set_public(&room_id)?; } info!("{} created a room", sender_user); - db.flush()?; - Ok(create_room::v3::Response::new(room_id)) } @@ -420,12 +421,15 @@ pub async fn create_room_route( /// /// - You have to currently be joined to the room (TODO: Respect history visibility) pub async fn get_room_event_route( - db: DatabaseGuard, - body: Ruma<get_room_event::v3::IncomingRequest>, + body: Ruma<get_room_event::v3::Request>, ) -> Result<get_room_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -433,8 +437,9 @@ pub async fn get_room_event_route( } Ok(get_room_event::v3::Response { - event: db + event: services() .rooms + .timeline .get_pdu(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))? .to_room_event(), @@ -447,12 +452,15 @@ pub async fn get_room_event_route( /// /// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable pub async fn get_room_aliases_route( - db: DatabaseGuard, - body: Ruma<aliases::v3::IncomingRequest>, + body: Ruma<aliases::v3::Request>, ) -> Result<aliases::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -460,9 +468,10 @@ pub async fn get_room_aliases_route( } Ok(aliases::v3::Response { - aliases: db + aliases: services() .rooms - .room_aliases(&body.room_id) + .alias + .local_aliases_for_room(&body.room_id) .filter_map(|a| a.ok()) .collect(), }) @@ -479,12 +488,15 @@ pub async fn get_room_aliases_route( /// - Moves local aliases /// - Modifies old room power levels to prevent users from speaking pub async fn upgrade_room_route( - db: DatabaseGuard, - body: Ruma<upgrade_room::v3::IncomingRequest>, + body: Ruma<upgrade_room::v3::Request>, ) -> Result<upgrade_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_supported_version(&db, &body.new_version) { + if !services() + .globals + .supported_room_versions() + .contains(&body.new_version) + { return Err(Error::BadRequest( ErrorKind::UnsupportedRoomVersion, "This server does not support that room version.", @@ -492,12 +504,15 @@ pub async fn upgrade_room_route( } // Create a replacement room - let replacement_room = RoomId::new(db.globals.server_name()); - db.rooms - .get_or_create_shortroomid(&replacement_room, &db.globals)?; + let replacement_room = RoomId::new(services().globals.server_name()); + services() + .rooms + .short + .get_or_create_shortroomid(&replacement_room)?; let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -508,7 +523,7 @@ pub async fn upgrade_room_route( // Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further // Fail if the sender does not have the required permissions - let tombstone_event_id = db.rooms.build_and_append_pdu( + let tombstone_event_id = services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomTombstone, content: to_raw_value(&RoomTombstoneEventContent { @@ -522,14 +537,14 @@ pub async fn upgrade_room_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; // Change lock to replacement room drop(state_lock); let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -540,7 +555,9 @@ pub async fn upgrade_room_route( // Get the old room creation event let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>( - db.rooms + services() + .rooms + .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .content @@ -588,7 +605,7 @@ pub async fn upgrade_room_route( )); } - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCreate, content: to_raw_value(&create_event_content) @@ -599,21 +616,20 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; // Join the new room - db.rooms.build_and_append_pdu( + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -624,7 +640,6 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; @@ -643,12 +658,17 @@ pub async fn upgrade_room_route( // Replicate transferable state events to the new room for event_type in transferable_state_events { - let event_content = match db.rooms.room_state_get(&body.room_id, &event_type, "")? { - Some(v) => v.content.clone(), - None => continue, // Skipping missing events. - }; - - db.rooms.build_and_append_pdu( + let event_content = + match services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &event_type, "")? + { + Some(v) => v.content.clone(), + None => continue, // Skipping missing events. + }; + + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), content: event_content, @@ -658,20 +678,28 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; } // Moves any local aliases to the new room - for alias in db.rooms.room_aliases(&body.room_id).filter_map(|r| r.ok()) { - db.rooms - .set_alias(&alias, Some(&replacement_room), &db.globals)?; + for alias in services() + .rooms + .alias + .local_aliases_for_room(&body.room_id) + .filter_map(|r| r.ok()) + { + services() + .rooms + .alias + .set_alias(&alias, &replacement_room)?; } // Get the old room power levels let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( - db.rooms + services() + .rooms + .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .content @@ -685,7 +713,7 @@ pub async fn upgrade_room_route( power_levels_event_content.invite = new_level; // Modify the power levels in the old room to prevent sending of events and inviting new users - let _ = db.rooms.build_and_append_pdu( + let _ = services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomPowerLevels, content: to_raw_value(&power_levels_event_content) @@ -696,14 +724,11 @@ pub async fn upgrade_room_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - // Return the replacement room id Ok(upgrade_room::v3::Response { replacement_room }) } diff --git a/src/client_server/search.rs b/src/api/client_server/search.rs index 686e3b5..51255d5 100644 --- a/src/client_server/search.rs +++ b/src/api/client_server/search.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use ruma::api::client::{ error::ErrorKind, search::search_events::{ @@ -15,8 +15,7 @@ use std::collections::BTreeMap; /// /// - Only works if the user is currently joined to the room (TODO: Respect history visibility) pub async fn search_events_route( - db: DatabaseGuard, - body: Ruma<search_events::v3::IncomingRequest>, + body: Ruma<search_events::v3::Request>, ) -> Result<search_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -24,7 +23,9 @@ pub async fn search_events_route( let filter = &search_criteria.filter; let room_ids = filter.rooms.clone().unwrap_or_else(|| { - db.rooms + services() + .rooms + .state_cache .rooms_joined(sender_user) .filter_map(|r| r.ok()) .collect() @@ -35,15 +36,20 @@ pub async fn search_events_route( let mut searches = Vec::new(); for room_id in room_ids { - if !db.rooms.is_joined(sender_user, &room_id)? { + if !services() + .rooms + .state_cache + .is_joined(sender_user, &room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", )); } - if let Some(search) = db + if let Some(search) = services() .rooms + .search .search_pdus(&room_id, &search_criteria.search_term)? { searches.push(search.0.peekable()); @@ -85,8 +91,9 @@ pub async fn search_events_route( start: None, }, rank: None, - result: db + result: services() .rooms + .timeline .get_pdu_from_id(result)? .map(|pdu| pdu.to_room_event()), }) @@ -96,7 +103,7 @@ pub async fn search_events_route( .take(limit) .collect(); - let next_batch = if results.len() < limit as usize { + let next_batch = if results.len() < limit { None } else { Some((skip + limit).to_string()) diff --git a/src/client_server/session.rs b/src/api/client_server/session.rs index c2a79ca..64c0072 100644 --- a/src/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,10 +1,10 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{services, utils, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, session::{get_login_types, login, logout, logout_all}, - uiaa::IncomingUserIdentifier, + uiaa::UserIdentifier, }, UserId, }; @@ -22,7 +22,7 @@ struct Claims { /// Get the supported login types of this server. One of these should be used as the `type` field /// when logging in. pub async fn get_login_types_route( - _body: Ruma<get_login_types::v3::IncomingRequest>, + _body: Ruma<get_login_types::v3::Request>, ) -> Result<get_login_types::v3::Response> { Ok(get_login_types::v3::Response::new(vec![ get_login_types::v3::LoginType::Password(Default::default()), @@ -40,31 +40,31 @@ pub async fn get_login_types_route( /// /// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see /// supported login types. -pub async fn login_route( - db: DatabaseGuard, - body: Ruma<login::v3::IncomingRequest>, -) -> Result<login::v3::Response> { +pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> { // Validate login method // TODO: Other login methods let user_id = match &body.login_info { - login::v3::IncomingLoginInfo::Password(login::v3::IncomingPassword { + login::v3::LoginInfo::Password(login::v3::Password { identifier, password, }) => { - let username = if let IncomingUserIdentifier::UserIdOrLocalpart(user_id) = identifier { + let username = if let UserIdentifier::UserIdOrLocalpart(user_id) = identifier { user_id.to_lowercase() } else { return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); }; let user_id = - UserId::parse_with_server_name(username.to_owned(), db.globals.server_name()) + UserId::parse_with_server_name(username, services().globals.server_name()) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") })?; - let hash = db.users.password_hash(&user_id)?.ok_or(Error::BadRequest( - ErrorKind::Forbidden, - "Wrong username or password.", - ))?; + let hash = services() + .users + .password_hash(&user_id)? + .ok_or(Error::BadRequest( + ErrorKind::Forbidden, + "Wrong username or password.", + ))?; if hash.is_empty() { return Err(Error::BadRequest( @@ -84,16 +84,16 @@ pub async fn login_route( user_id } - login::v3::IncomingLoginInfo::Token(login::v3::IncomingToken { token }) => { - if let Some(jwt_decoding_key) = db.globals.jwt_decoding_key() { + login::v3::LoginInfo::Token(login::v3::Token { token }) => { + if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { let token = jsonwebtoken::decode::<Claims>( token, jwt_decoding_key, &jsonwebtoken::Validation::default(), ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; - let username = token.claims.sub; - UserId::parse_with_server_name(username, db.globals.server_name()).map_err( + let username = token.claims.sub.to_lowercase(); + UserId::parse_with_server_name(username, services().globals.server_name()).map_err( |_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."), )? } else { @@ -122,15 +122,16 @@ pub async fn login_route( // Determine if device_id was provided and exists in the db for this user let device_exists = body.device_id.as_ref().map_or(false, |device_id| { - db.users + services() + .users .all_device_ids(&user_id) .any(|x| x.as_ref().map_or(false, |v| v == device_id)) }); if device_exists { - db.users.set_token(&user_id, &device_id, &token)?; + services().users.set_token(&user_id, &device_id, &token)?; } else { - db.users.create_device( + services().users.create_device( &user_id, &device_id, &token, @@ -140,14 +141,14 @@ pub async fn login_route( info!("{} logged in", user_id); - db.flush()?; - Ok(login::v3::Response { user_id, access_token: token, - home_server: Some(db.globals.server_name().to_owned()), + home_server: Some(services().globals.server_name().to_owned()), device_id, well_known: None, + refresh_token: None, + expires_in: None, }) } @@ -159,16 +160,11 @@ pub async fn login_route( /// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) /// - Forgets to-device events /// - Triggers device list updates -pub async fn logout_route( - db: DatabaseGuard, - body: Ruma<logout::v3::Request>, -) -> Result<logout::v3::Response> { +pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - db.users.remove_device(sender_user, sender_device)?; - - db.flush()?; + services().users.remove_device(sender_user, sender_device)?; Ok(logout::v3::Response::new()) } @@ -185,16 +181,13 @@ pub async fn logout_route( /// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html) /// from each device of this user. pub async fn logout_all_route( - db: DatabaseGuard, body: Ruma<logout_all::v3::Request>, ) -> Result<logout_all::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in db.users.all_device_ids(sender_user).flatten() { - db.users.remove_device(sender_user, &device_id)?; + for device_id in services().users.all_device_ids(sender_user).flatten() { + services().users.remove_device(sender_user, &device_id)?; } - db.flush()?; - Ok(logout_all::v3::Response::new()) } diff --git a/src/client_server/state.rs b/src/api/client_server/state.rs index 4df953c..d9c1464 100644 --- a/src/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -1,8 +1,6 @@ use std::sync::Arc; -use crate::{ - database::DatabaseGuard, pdu::PduBuilder, Database, Error, Result, Ruma, RumaResponse, -}; +use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse}; use ruma::{ api::client::{ error::ErrorKind, @@ -27,13 +25,11 @@ use ruma::{ /// - Tries to send the event into the room, auth rules will determine if it is allowed /// - If event is new canonical_alias: Rejects if alias is incorrect pub async fn send_state_event_for_key_route( - db: DatabaseGuard, - body: Ruma<send_state_event::v3::IncomingRequest>, + body: Ruma<send_state_event::v3::Request>, ) -> Result<send_state_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let event_id = send_state_event_for_key_helper( - &db, sender_user, &body.room_id, &body.event_type, @@ -42,8 +38,6 @@ pub async fn send_state_event_for_key_route( ) .await?; - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(send_state_event::v3::Response { event_id }) } @@ -56,13 +50,12 @@ pub async fn send_state_event_for_key_route( /// - Tries to send the event into the room, auth rules will determine if it is allowed /// - If event is new canonical_alias: Rejects if alias is incorrect pub async fn send_state_event_for_empty_key_route( - db: DatabaseGuard, - body: Ruma<send_state_event::v3::IncomingRequest>, + body: Ruma<send_state_event::v3::Request>, ) -> Result<RumaResponse<send_state_event::v3::Response>> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // Forbid m.room.encryption if encryption is disabled - if body.event_type == StateEventType::RoomEncryption && !db.globals.allow_encryption() { + if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() { return Err(Error::BadRequest( ErrorKind::Forbidden, "Encryption has been disabled", @@ -70,7 +63,6 @@ pub async fn send_state_event_for_empty_key_route( } let event_id = send_state_event_for_key_helper( - &db, sender_user, &body.room_id, &body.event_type.to_string().into(), @@ -79,8 +71,6 @@ pub async fn send_state_event_for_empty_key_route( ) .await?; - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(send_state_event::v3::Response { event_id }.into()) } @@ -91,17 +81,21 @@ pub async fn send_state_event_for_empty_key_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_route( - db: DatabaseGuard, - body: Ruma<get_state_events::v3::IncomingRequest>, + body: Ruma<get_state_events::v3::Request>, ) -> Result<get_state_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services() + .rooms + .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -122,8 +116,9 @@ pub async fn get_state_events_route( } Ok(get_state_events::v3::Response { - room_state: db + room_state: services() .rooms + .state_accessor .room_state_full(&body.room_id) .await? .values() @@ -138,17 +133,21 @@ pub async fn get_state_events_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_for_key_route( - db: DatabaseGuard, - body: Ruma<get_state_events_for_key::v3::IncomingRequest>, + body: Ruma<get_state_events_for_key::v3::Request>, ) -> Result<get_state_events_for_key::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services() + .rooms + .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -168,8 +167,9 @@ pub async fn get_state_events_for_key_route( )); } - let event = db + let event = services() .rooms + .state_accessor .room_state_get(&body.room_id, &body.event_type, &body.state_key)? .ok_or(Error::BadRequest( ErrorKind::NotFound, @@ -188,17 +188,21 @@ pub async fn get_state_events_for_key_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_for_empty_key_route( - db: DatabaseGuard, - body: Ruma<get_state_events_for_key::v3::IncomingRequest>, + body: Ruma<get_state_events_for_key::v3::Request>, ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services() + .rooms + .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -218,8 +222,9 @@ pub async fn get_state_events_for_empty_key_route( )); } - let event = db + let event = services() .rooms + .state_accessor .room_state_get(&body.room_id, &body.event_type, "")? .ok_or(Error::BadRequest( ErrorKind::NotFound, @@ -234,7 +239,6 @@ pub async fn get_state_events_for_empty_key_route( } async fn send_state_event_for_key_helper( - db: &Database, sender: &UserId, room_id: &RoomId, event_type: &StateEventType, @@ -255,10 +259,11 @@ async fn send_state_event_for_key_helper( } for alias in aliases { - if alias.server_name() != db.globals.server_name() - || db + if alias.server_name() != services().globals.server_name() + || services() .rooms - .id_from_alias(&alias)? + .alias + .resolve_local_alias(&alias)? .filter(|room| room == room_id) // Make sure it's the right room .is_none() { @@ -272,7 +277,8 @@ async fn send_state_event_for_key_helper( } let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -281,7 +287,7 @@ async fn send_state_event_for_key_helper( ); let state_lock = mutex_state.lock().await; - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), content: serde_json::from_str(json.json().get()).expect("content is valid json"), @@ -291,7 +297,6 @@ async fn send_state_event_for_key_helper( }, sender_user, room_id, - db, &state_lock, )?; diff --git a/src/client_server/sync.rs b/src/api/client_server/sync.rs index 0c294b7..568a23c 100644 --- a/src/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -1,8 +1,8 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma, RumaResponse}; +use crate::{services, Error, Result, Ruma, RumaResponse}; use ruma::{ api::client::{ - filter::{IncomingFilterDefinition, LazyLoadOptions}, - sync::sync_events, + filter::{FilterDefinition, LazyLoadOptions}, + sync::sync_events::{self, DeviceLists, UnreadNotificationsCount}, uiaa::UiaaResponse, }, events::{ @@ -10,7 +10,7 @@ use ruma::{ RoomEventType, StateEventType, }, serde::Raw, - DeviceId, RoomId, UserId, + OwnedDeviceId, OwnedUserId, RoomId, UserId, }; use std::{ collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, @@ -55,16 +55,13 @@ use tracing::error; /// - Sync is handled in an async task, multiple requests from the same device with the same /// `since` will be cached pub async fn sync_events_route( - db: DatabaseGuard, - body: Ruma<sync_events::v3::IncomingRequest>, + body: Ruma<sync_events::v3::Request>, ) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> { let sender_user = body.sender_user.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated"); let body = body.body; - let arc_db = Arc::new(db); - - let mut rx = match arc_db + let mut rx = match services() .globals .sync_receivers .write() @@ -77,7 +74,6 @@ pub async fn sync_events_route( v.insert((body.since.to_owned(), rx.clone())); tokio::spawn(sync_helper_wrapper( - Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body, @@ -93,7 +89,6 @@ pub async fn sync_events_route( o.insert((body.since.clone(), rx.clone())); tokio::spawn(sync_helper_wrapper( - Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body, @@ -127,25 +122,18 @@ pub async fn sync_events_route( } async fn sync_helper_wrapper( - db: Arc<DatabaseGuard>, - sender_user: Box<UserId>, - sender_device: Box<DeviceId>, - body: sync_events::v3::IncomingRequest, + sender_user: OwnedUserId, + sender_device: OwnedDeviceId, + body: sync_events::v3::Request, tx: Sender<Option<Result<sync_events::v3::Response>>>, ) { let since = body.since.clone(); - let r = sync_helper( - Arc::clone(&db), - sender_user.clone(), - sender_device.clone(), - body, - ) - .await; + let r = sync_helper(sender_user.clone(), sender_device.clone(), body).await; if let Ok((_, caching_allowed)) = r { if !caching_allowed { - match db + match services() .globals .sync_receivers .write() @@ -163,38 +151,34 @@ async fn sync_helper_wrapper( } } - drop(db); - let _ = tx.send(Some(r.map(|(r, _)| r))); } async fn sync_helper( - db: Arc<DatabaseGuard>, - sender_user: Box<UserId>, - sender_device: Box<DeviceId>, - body: sync_events::v3::IncomingRequest, + sender_user: OwnedUserId, + sender_device: OwnedDeviceId, + body: sync_events::v3::Request, // bool = caching allowed ) -> Result<(sync_events::v3::Response, bool), Error> { use sync_events::v3::{ - DeviceLists, Ephemeral, GlobalAccountData, IncomingFilter, InviteState, InvitedRoom, - JoinedRoom, LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, State, Timeline, - ToDevice, UnreadNotificationsCount, + Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, + Presence, RoomAccountData, RoomSummary, Rooms, State, Timeline, ToDevice, }; // TODO: match body.set_presence { - db.rooms.edus.ping_presence(&sender_user)?; + services().rooms.edus.presence.ping_presence(&sender_user)?; // Setup watchers, so if there's no response, we can wait for them - let watcher = db.watch(&sender_user, &sender_device); + let watcher = services().globals.watch(&sender_user, &sender_device); - let next_batch = db.globals.current_count()?; + let next_batch = services().globals.current_count()?; let next_batch_string = next_batch.to_string(); // Load filter let filter = match body.filter { - None => IncomingFilterDefinition::default(), - Some(IncomingFilter::FilterDefinition(filter)) => filter, - Some(IncomingFilter::FilterId(filter_id)) => db + None => FilterDefinition::default(), + Some(Filter::FilterDefinition(filter)) => filter, + Some(Filter::FilterId(filter_id)) => services() .users .get_filter(&sender_user, &filter_id)? .unwrap_or_default(), @@ -221,12 +205,17 @@ async fn sync_helper( // Look for device list updates of this account device_list_updates.extend( - db.users - .keys_changed(&sender_user.to_string(), since, None) + services() + .users + .keys_changed(sender_user.as_ref(), since, None) .filter_map(|r| r.ok()), ); - let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::<Vec<_>>(); + let all_joined_rooms = services() + .rooms + .state_cache + .rooms_joined(&sender_user) + .collect::<Vec<_>>(); for room_id in all_joined_rooms { let room_id = room_id?; @@ -234,7 +223,8 @@ async fn sync_helper( // Get and drop the lock to wait for remaining operations to finish // This will make sure the we have all events until next_batch let mutex_insert = Arc::clone( - db.globals + services() + .globals .roomid_mutex_insert .write() .unwrap() @@ -247,9 +237,15 @@ async fn sync_helper( let timeline_pdus; let limited; - if db.rooms.last_timeline_count(&sender_user, &room_id)? > since { - let mut non_timeline_pdus = db + if services() + .rooms + .timeline + .last_timeline_count(&sender_user, &room_id)? + > since + { + let mut non_timeline_pdus = services() .rooms + .timeline .pdus_until(&sender_user, &room_id, u64::MAX)? .filter_map(|r| { // Filter out buggy events @@ -259,7 +255,9 @@ async fn sync_helper( r.ok() }) .take_while(|(pduid, _)| { - db.rooms + services() + .rooms + .timeline .pdu_count(pduid) .map_or(false, |count| count > since) }); @@ -282,10 +280,10 @@ async fn sync_helper( } let send_notification_counts = !timeline_pdus.is_empty() - || db + || services() .rooms - .edus - .last_privateread_update(&sender_user, &room_id)? + .user + .last_notification_read(&sender_user, &room_id)? > since; let mut timeline_users = HashSet::new(); @@ -293,24 +291,40 @@ async fn sync_helper( timeline_users.insert(event.sender.as_str().to_owned()); } - db.rooms - .lazy_load_confirm_delivery(&sender_user, &sender_device, &room_id, since)?; + services().rooms.lazy_loading.lazy_load_confirm_delivery( + &sender_user, + &sender_device, + &room_id, + since, + )?; // Database queries: - let current_shortstatehash = if let Some(s) = db.rooms.current_shortstatehash(&room_id)? { - s - } else { - error!("Room {} has no state", room_id); - continue; - }; + let current_shortstatehash = + if let Some(s) = services().rooms.state.get_room_shortstatehash(&room_id)? { + s + } else { + error!("Room {} has no state", room_id); + continue; + }; - let since_shortstatehash = db.rooms.get_token_shortstatehash(&room_id, since)?; + let since_shortstatehash = services() + .rooms + .user + .get_token_shortstatehash(&room_id, since)?; // Calculates joined_member_count, invited_member_count and heroes let calculate_counts = || { - let joined_member_count = db.rooms.room_joined_count(&room_id)?.unwrap_or(0); - let invited_member_count = db.rooms.room_invited_count(&room_id)?.unwrap_or(0); + let joined_member_count = services() + .rooms + .state_cache + .room_joined_count(&room_id)? + .unwrap_or(0); + let invited_member_count = services() + .rooms + .state_cache + .room_invited_count(&room_id)? + .unwrap_or(0); // Recalculate heroes (first 5 members) let mut heroes = Vec::new(); @@ -319,8 +333,9 @@ async fn sync_helper( // Go through all PDUs and for each member event, check if the user is still joined or // invited until we have 5 or we reach the end - for hero in db + for hero in services() .rooms + .timeline .all_pdus(&sender_user, &room_id)? .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter(|(_, pdu)| pdu.kind == RoomEventType::RoomMember) @@ -339,8 +354,11 @@ async fn sync_helper( if matches!( content.membership, MembershipState::Join | MembershipState::Invite - ) && (db.rooms.is_joined(&user_id, &room_id)? - || db.rooms.is_invited(&user_id, &room_id)?) + ) && (services().rooms.state_cache.is_joined(&user_id, &room_id)? + || services() + .rooms + .state_cache + .is_invited(&user_id, &room_id)?) { Ok::<_, Error>(Some(state_key.clone())) } else { @@ -370,28 +388,57 @@ async fn sync_helper( )) }; + let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash + .and_then(|shortstatehash| { + services() + .rooms + .state_accessor + .state_get( + shortstatehash, + &StateEventType::RoomMember, + sender_user.as_str(), + ) + .transpose() + }) + .transpose()? + .and_then(|pdu| { + serde_json::from_str(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }); + + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + let ( heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events, - ) = if since_shortstatehash.is_none() { + ) = if since_shortstatehash.is_none() || joined_since_last_sync { // Probably since = 0, we will do an initial sync let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; + let current_state_ids = services() + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; let mut state_events = Vec::new(); let mut lazy_loaded = HashSet::new(); let mut i = 0; for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; + let (event_type, state_key) = services() + .rooms + .short + .get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.timeline.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -407,8 +454,10 @@ async fn sync_helper( } else if !lazy_load_enabled || body.full_state || timeline_users.contains(&state_key) + // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 + || *sender_user == state_key { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.timeline.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -417,7 +466,7 @@ async fn sync_helper( }; // This check is in case a bad user ID made it into the database - if let Ok(uid) = UserId::parse(state_key.as_ref()) { + if let Ok(uid) = UserId::parse(&state_key) { lazy_loaded.insert(uid); } state_events.push(pdu); @@ -430,12 +479,15 @@ async fn sync_helper( } // Reset lazy loading because this is an initial sync - db.rooms - .lazy_load_reset(&sender_user, &sender_device, &room_id)?; + services().rooms.lazy_loading.lazy_load_reset( + &sender_user, + &sender_device, + &room_id, + )?; // The state_events above should contain all timeline_users, let's mark them as lazy // loaded. - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_loading.lazy_load_mark_sent( &sender_user, &sender_device, &room_id, @@ -457,32 +509,24 @@ async fn sync_helper( // Incremental /sync let since_shortstatehash = since_shortstatehash.unwrap(); - let since_sender_member: Option<RoomMemberEventContent> = db - .rooms - .state_get( - since_shortstatehash, - &StateEventType::RoomMember, - sender_user.as_str(), - )? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); - - let joined_since_last_sync = since_sender_member - .map_or(true, |member| member.membership != MembershipState::Join); - let mut state_events = Vec::new(); let mut lazy_loaded = HashSet::new(); if since_shortstatehash != current_shortstatehash { - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; - let since_state_ids = db.rooms.state_full_ids(since_shortstatehash).await?; + let current_state_ids = services() + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + let since_state_ids = services() + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .await?; for (key, id) in current_state_ids { if body.full_state || since_state_ids.get(&key) != Some(&id) { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.timeline.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -515,14 +559,14 @@ async fn sync_helper( continue; } - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_loading.lazy_load_was_sent_before( &sender_user, &sender_device, &room_id, &event.sender, )? || lazy_load_send_redundant { - if let Some(member_event) = db.rooms.room_state_get( + if let Some(member_event) = services().rooms.state_accessor.room_state_get( &room_id, &StateEventType::RoomMember, event.sender.as_str(), @@ -533,7 +577,7 @@ async fn sync_helper( } } - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_loading.lazy_load_mark_sent( &sender_user, &sender_device, &room_id, @@ -541,14 +585,17 @@ async fn sync_helper( next_batch, ); - let encrypted_room = db + let encrypted_room = services() .rooms + .state_accessor .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? .is_some(); - let since_encryption = - db.rooms - .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?; + let since_encryption = services().rooms.state_accessor.state_get( + since_shortstatehash, + &StateEventType::RoomEncryption, + "", + )?; // Calculations: let new_encrypted_room = encrypted_room && since_encryption.is_none(); @@ -580,7 +627,7 @@ async fn sync_helper( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { + if !share_encrypted_room(&sender_user, &user_id, &room_id)? { device_list_updates.insert(user_id); } } @@ -597,7 +644,9 @@ async fn sync_helper( if joined_since_last_sync && encrypted_room || new_encrypted_room { // If the user is in a new encrypted room, give them all joined users device_list_updates.extend( - db.rooms + services() + .rooms + .state_cache .room_members(&room_id) .flatten() .filter(|user_id| { @@ -606,8 +655,7 @@ async fn sync_helper( }) .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(&db, &sender_user, user_id, &room_id) - .unwrap_or(false) + !share_encrypted_room(&sender_user, user_id, &room_id).unwrap_or(false) }), ); } @@ -629,14 +677,17 @@ async fn sync_helper( // Look for device list updates in this room device_list_updates.extend( - db.users - .keys_changed(&room_id.to_string(), since, None) + services() + .users + .keys_changed(room_id.as_ref(), since, None) .filter_map(|r| r.ok()), ); let notification_count = if send_notification_counts { Some( - db.rooms + services() + .rooms + .user .notification_count(&sender_user, &room_id)? .try_into() .expect("notification count can't go that high"), @@ -647,7 +698,9 @@ async fn sync_helper( let highlight_count = if send_notification_counts { Some( - db.rooms + services() + .rooms + .user .highlight_count(&sender_user, &room_id)? .try_into() .expect("highlight count can't go that high"), @@ -659,7 +712,9 @@ async fn sync_helper( let prev_batch = timeline_pdus .first() .map_or(Ok::<_, Error>(None), |(pdu_id, _)| { - Ok(Some(db.rooms.pdu_count(pdu_id)?.to_string())) + Ok(Some( + services().rooms.timeline.pdu_count(pdu_id)?.to_string(), + )) })?; let room_events: Vec<_> = timeline_pdus @@ -667,18 +722,19 @@ async fn sync_helper( .map(|(_, pdu)| pdu.to_sync_room_event()) .collect(); - let mut edus: Vec<_> = db + let mut edus: Vec<_> = services() .rooms .edus + .read_receipt .readreceipts_since(&room_id, since) .filter_map(|r| r.ok()) // Filter out buggy events .map(|(_, _, v)| v) .collect(); - if db.rooms.edus.last_typing_update(&room_id, &db.globals)? > since { + if services().rooms.edus.typing.last_typing_update(&room_id)? > since { edus.push( serde_json::from_str( - &serde_json::to_string(&db.rooms.edus.typings_all(&room_id)?) + &serde_json::to_string(&services().rooms.edus.typing.typings_all(&room_id)?) .expect("event is valid, we just created it"), ) .expect("event is valid, we just created it"), @@ -686,12 +742,15 @@ async fn sync_helper( } // Save the state after this sync so we can send the correct state diff next sync - db.rooms - .associate_token_shortstatehash(&room_id, next_batch, current_shortstatehash)?; + services().rooms.user.associate_token_shortstatehash( + &room_id, + next_batch, + current_shortstatehash, + )?; let joined_room = JoinedRoom { account_data: RoomAccountData { - events: db + events: services() .account_data .changes_since(Some(&room_id), &sender_user, since)? .into_iter() @@ -723,6 +782,7 @@ async fn sync_helper( .collect(), }, ephemeral: Ephemeral { events: edus }, + unread_thread_notifications: BTreeMap::new(), }; if !joined_room.is_empty() { @@ -730,10 +790,11 @@ async fn sync_helper( } // Take presence updates from this room - for (user_id, presence) in - db.rooms - .edus - .presence_since(&room_id, since, &db.rooms, &db.globals)? + for (user_id, presence) in services() + .rooms + .edus + .presence + .presence_since(&room_id, since)? { match presence_updates.entry(user_id) { Entry::Vacant(v) => { @@ -765,14 +826,21 @@ async fn sync_helper( } let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = db.rooms.rooms_left(&sender_user).collect(); + let all_left_rooms: Vec<_> = services() + .rooms + .state_cache + .rooms_left(&sender_user) + .collect(); for result in all_left_rooms { - let (room_id, left_state_events) = result?; + let (room_id, _) = result?; + + let mut left_state_events = Vec::new(); { // Get and drop the lock to wait for remaining operations to finish let mutex_insert = Arc::clone( - db.globals + services() + .globals .roomid_mutex_insert .write() .unwrap() @@ -783,13 +851,98 @@ async fn sync_helper( drop(insert_lock); } - let left_count = db.rooms.get_left_count(&room_id, &sender_user)?; + let left_count = services() + .rooms + .state_cache + .get_left_count(&room_id, &sender_user)?; // Left before last sync if Some(since) >= left_count { continue; } + if !services().rooms.metadata.exists(&room_id)? { + // This is just a rejected invite, not a room we know + continue; + } + + let since_shortstatehash = services() + .rooms + .user + .get_token_shortstatehash(&room_id, since)?; + + let since_state_ids = match since_shortstatehash { + Some(s) => services().rooms.state_accessor.state_full_ids(s).await?, + None => HashMap::new(), + }; + + let left_event_id = match services().rooms.state_accessor.room_state_get_id( + &room_id, + &StateEventType::RoomMember, + sender_user.as_str(), + )? { + Some(e) => e, + None => { + error!("Left room but no left state event"); + continue; + } + }; + + let left_shortstatehash = match services() + .rooms + .state_accessor + .pdu_shortstatehash(&left_event_id)? + { + Some(s) => s, + None => { + error!("Leave event has no state"); + continue; + } + }; + + let mut left_state_ids = services() + .rooms + .state_accessor + .state_full_ids(left_shortstatehash) + .await?; + + let leave_shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; + + left_state_ids.insert(leave_shortstatekey, left_event_id); + + let mut i = 0; + for (key, id) in left_state_ids { + if body.full_state || since_state_ids.get(&key) != Some(&id) { + let (event_type, state_key) = + services().rooms.short.get_statekey_from_short(key)?; + + if !lazy_load_enabled + || event_type != StateEventType::RoomMember + || body.full_state + // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 + || *sender_user == state_key + { + let pdu = match services().rooms.timeline.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + } + }; + + left_state_events.push(pdu.to_sync_state_event()); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + } + } + left_rooms.insert( room_id.clone(), LeftRoom { @@ -807,14 +960,19 @@ async fn sync_helper( } let mut invited_rooms = BTreeMap::new(); - let all_invited_rooms: Vec<_> = db.rooms.rooms_invited(&sender_user).collect(); + let all_invited_rooms: Vec<_> = services() + .rooms + .state_cache + .rooms_invited(&sender_user) + .collect(); for result in all_invited_rooms { let (room_id, invite_state_events) = result?; { // Get and drop the lock to wait for remaining operations to finish let mutex_insert = Arc::clone( - db.globals + services() + .globals .roomid_mutex_insert .write() .unwrap() @@ -825,7 +983,10 @@ async fn sync_helper( drop(insert_lock); } - let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; + let invite_count = services() + .rooms + .state_cache + .get_invite_count(&room_id, &sender_user)?; // Invited before last sync if Some(since) >= invite_count { @@ -843,13 +1004,16 @@ async fn sync_helper( } for user_id in left_encrypted_users { - let still_share_encrypted_room = db + let still_share_encrypted_room = services() .rooms + .user .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .filter_map(|r| r.ok()) .filter_map(|other_room_id| { Some( - db.rooms + services() + .rooms + .state_accessor .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .ok()? .is_some(), @@ -864,7 +1028,8 @@ async fn sync_helper( } // Remove all to-device events the device received *last time* - db.users + services() + .users .remove_to_device_events(&sender_user, &sender_device, since)?; let response = sync_events::v3::Response { @@ -877,12 +1042,12 @@ async fn sync_helper( }, presence: Presence { events: presence_updates - .into_iter() - .map(|(_, v)| Raw::new(&v).expect("PresenceEvent always serializes successfully")) + .into_values() + .map(|v| Raw::new(&v).expect("PresenceEvent always serializes successfully")) .collect(), }, account_data: GlobalAccountData { - events: db + events: services() .account_data .changes_since(None, &sender_user, since)? .into_iter() @@ -897,9 +1062,11 @@ async fn sync_helper( changed: device_list_updates.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: db.users.count_one_time_keys(&sender_user, &sender_device)?, + device_one_time_keys_count: services() + .users + .count_one_time_keys(&sender_user, &sender_device)?, to_device: ToDevice { - events: db + events: services() .users .get_to_device_events(&sender_user, &sender_device)?, }, @@ -928,21 +1095,22 @@ async fn sync_helper( } } -#[tracing::instrument(skip(db))] fn share_encrypted_room( - db: &Database, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, ) -> Result<bool> { - Ok(db + Ok(services() .rooms + .user .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? .filter_map(|r| r.ok()) .filter(|room_id| room_id != ignore_room) .filter_map(|other_room_id| { Some( - db.rooms + services() + .rooms + .state_accessor .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .ok()? .is_some(), diff --git a/src/api/client_server/tag.rs b/src/api/client_server/tag.rs new file mode 100644 index 0000000..16f1600 --- /dev/null +++ b/src/api/client_server/tag.rs @@ -0,0 +1,126 @@ +use crate::{services, Error, Result, Ruma}; +use ruma::{ + api::client::tag::{create_tag, delete_tag, get_tags}, + events::{ + tag::{TagEvent, TagEventContent}, + RoomAccountDataEventType, + }, +}; +use std::collections::BTreeMap; + +/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` +/// +/// Adds a tag to the room. +/// +/// - Inserts the tag into the tag event of the room account data. +pub async fn update_tag_route( + body: Ruma<create_tag::v3::Request>, +) -> Result<create_tag::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let event = services().account_data.get( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + )?; + + let mut tags_event = event + .map(|e| { + serde_json::from_str(e.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db.")) + }) + .unwrap_or_else(|| { + Ok(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }) + })?; + + tags_event + .content + .tags + .insert(body.tag.clone().into(), body.tag_info.clone()); + + services().account_data.update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + )?; + + Ok(create_tag::v3::Response {}) +} + +/// # `DELETE /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` +/// +/// Deletes a tag from the room. +/// +/// - Removes the tag from the tag event of the room account data. +pub async fn delete_tag_route( + body: Ruma<delete_tag::v3::Request>, +) -> Result<delete_tag::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let event = services().account_data.get( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + )?; + + let mut tags_event = event + .map(|e| { + serde_json::from_str(e.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db.")) + }) + .unwrap_or_else(|| { + Ok(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }) + })?; + + tags_event.content.tags.remove(&body.tag.clone().into()); + + services().account_data.update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + )?; + + Ok(delete_tag::v3::Response {}) +} + +/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags` +/// +/// Returns tags on the room. +/// +/// - Gets the tag event of the room account data. +pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let event = services().account_data.get( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + )?; + + let tags_event = event + .map(|e| { + serde_json::from_str(e.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db.")) + }) + .unwrap_or_else(|| { + Ok(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }) + })?; + + Ok(get_tags::v3::Response { + tags: tags_event.content.tags, + }) +} diff --git a/src/client_server/thirdparty.rs b/src/api/client_server/thirdparty.rs index 5665ad6..c2c1adf 100644 --- a/src/client_server/thirdparty.rs +++ b/src/api/client_server/thirdparty.rs @@ -7,7 +7,7 @@ use std::collections::BTreeMap; /// /// TODO: Fetches all metadata about protocols supported by the homeserver. pub async fn get_protocols_route( - _body: Ruma<get_protocols::v3::IncomingRequest>, + _body: Ruma<get_protocols::v3::Request>, ) -> Result<get_protocols::v3::Response> { // TODO Ok(get_protocols::v3::Response { diff --git a/src/client_server/to_device.rs b/src/api/client_server/to_device.rs index 51441dd..26db4e4 100644 --- a/src/client_server/to_device.rs +++ b/src/api/client_server/to_device.rs @@ -1,7 +1,7 @@ use ruma::events::ToDeviceEventType; use std::collections::BTreeMap; -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma}; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -14,14 +14,13 @@ use ruma::{ /// /// Send a to-device event to a set of client devices. pub async fn send_event_to_device_route( - db: DatabaseGuard, - body: Ruma<send_event_to_device::v3::IncomingRequest>, + body: Ruma<send_event_to_device::v3::Request>, ) -> Result<send_event_to_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); // Check if this is a new transaction id - if db + if services() .transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? .is_some() @@ -31,44 +30,46 @@ pub async fn send_event_to_device_route( for (target_user_id, map) in &body.messages { for (target_device_id_maybe, event) in map { - if target_user_id.server_name() != db.globals.server_name() { + if target_user_id.server_name() != services().globals.server_name() { let mut map = BTreeMap::new(); map.insert(target_device_id_maybe.clone(), event.clone()); let mut messages = BTreeMap::new(); messages.insert(target_user_id.clone(), map); + let count = services().globals.next_count()?; - db.sending.send_reliable_edu( + services().sending.send_reliable_edu( target_user_id.server_name(), serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( DirectDeviceContent { sender: sender_user.clone(), ev_type: ToDeviceEventType::from(&*body.event_type), - message_id: body.txn_id.to_owned(), + message_id: count.to_string().into(), messages, }, )) .expect("DirectToDevice EDU can be serialized"), - db.globals.next_count()?, + count, )?; continue; } match target_device_id_maybe { - DeviceIdOrAllDevices::DeviceId(target_device_id) => db.users.add_to_device_event( - sender_user, - target_user_id, - &target_device_id, - &body.event_type, - event.deserialize_as().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") - })?, - &db.globals, - )?, + DeviceIdOrAllDevices::DeviceId(target_device_id) => { + services().users.add_to_device_event( + sender_user, + target_user_id, + target_device_id, + &body.event_type, + event.deserialize_as().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") + })?, + )? + } DeviceIdOrAllDevices::AllDevices => { - for target_device_id in db.users.all_device_ids(target_user_id) { - db.users.add_to_device_event( + for target_device_id in services().users.all_device_ids(target_user_id) { + services().users.add_to_device_event( sender_user, target_user_id, &target_device_id?, @@ -76,7 +77,6 @@ pub async fn send_event_to_device_route( event.deserialize_as().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") })?, - &db.globals, )?; } } @@ -85,10 +85,9 @@ pub async fn send_event_to_device_route( } // Save transaction id with empty data - db.transaction_ids + services() + .transaction_ids .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; - db.flush()?; - Ok(send_event_to_device::v3::Response {}) } diff --git a/src/client_server/typing.rs b/src/api/client_server/typing.rs index cac5a5f..43217e1 100644 --- a/src/client_server/typing.rs +++ b/src/api/client_server/typing.rs @@ -1,18 +1,21 @@ -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{services, utils, Error, Result, Ruma}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// /// Sets the typing state of the sender user. pub async fn create_typing_event_route( - db: DatabaseGuard, - body: Ruma<create_typing_event::v3::IncomingRequest>, + body: Ruma<create_typing_event::v3::Request>, ) -> Result<create_typing_event::v3::Response> { use create_typing_event::v3::Typing; let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services() + .rooms + .state_cache + .is_joined(sender_user, &body.room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You are not in this room.", @@ -20,16 +23,17 @@ pub async fn create_typing_event_route( } if let Typing::Yes(duration) = body.state { - db.rooms.edus.typing_add( + services().rooms.edus.typing.typing_add( sender_user, &body.room_id, duration.as_millis() as u64 + utils::millis_since_unix_epoch(), - &db.globals, )?; } else { - db.rooms + services() + .rooms .edus - .typing_remove(sender_user, &body.room_id, &db.globals)?; + .typing + .typing_remove(sender_user, &body.room_id)?; } Ok(create_typing_event::v3::Response {}) diff --git a/src/client_server/unversioned.rs b/src/api/client_server/unversioned.rs index 8a5c3d2..526598b 100644 --- a/src/client_server/unversioned.rs +++ b/src/api/client_server/unversioned.rs @@ -15,7 +15,7 @@ use crate::{Result, Ruma}; /// Note: Unstable features are used while developing new features. Clients should avoid using /// unstable features in their stable releases pub async fn get_supported_versions_route( - _body: Ruma<get_supported_versions::IncomingRequest>, + _body: Ruma<get_supported_versions::Request>, ) -> Result<get_supported_versions::Response> { let resp = get_supported_versions::Response { versions: vec![ diff --git a/src/client_server/user_directory.rs b/src/api/client_server/user_directory.rs index 349c139..c30bac5 100644 --- a/src/client_server/user_directory.rs +++ b/src/api/client_server/user_directory.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{services, Result, Ruma}; use ruma::{ api::client::user_directory::search_users, events::{ @@ -14,20 +14,19 @@ use ruma::{ /// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public) /// and don't share a room with the sender pub async fn search_users_route( - db: DatabaseGuard, - body: Ruma<search_users::v3::IncomingRequest>, + body: Ruma<search_users::v3::Request>, ) -> Result<search_users::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = u64::from(body.limit) as usize; - let mut users = db.users.iter().filter_map(|user_id| { + let mut users = services().users.iter().filter_map(|user_id| { // Filter out buggy users (they should not exist, but you never know...) let user_id = user_id.ok()?; let user = search_users::v3::User { user_id: user_id.clone(), - display_name: db.users.displayname(&user_id).ok()?, - avatar_url: db.users.avatar_url(&user_id).ok()?, + display_name: services().users.displayname(&user_id).ok()?, + avatar_url: services().users.avatar_url(&user_id).ok()?, }; let user_id_matches = user @@ -49,30 +48,34 @@ pub async fn search_users_route( return None; } - let user_is_in_public_rooms = - db.rooms - .rooms_joined(&user_id) - .filter_map(|r| r.ok()) - .any(|room| { - db.rooms - .room_state_get(&room, &StateEventType::RoomJoinRules, "") - .map_or(false, |event| { - event.map_or(false, |event| { - serde_json::from_str(event.content.get()) - .map_or(false, |r: RoomJoinRulesEventContent| { - r.join_rule == JoinRule::Public - }) - }) + let user_is_in_public_rooms = services() + .rooms + .state_cache + .rooms_joined(&user_id) + .filter_map(|r| r.ok()) + .any(|room| { + services() + .rooms + .state_accessor + .room_state_get(&room, &StateEventType::RoomJoinRules, "") + .map_or(false, |event| { + event.map_or(false, |event| { + serde_json::from_str(event.content.get()) + .map_or(false, |r: RoomJoinRulesEventContent| { + r.join_rule == JoinRule::Public + }) }) - }); + }) + }); if user_is_in_public_rooms { return Some(user); } - let user_is_in_shared_rooms = db + let user_is_in_shared_rooms = services() .rooms - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) + .user + .get_shared_rooms(vec![sender_user.clone(), user_id]) .ok()? .next() .is_some(); diff --git a/src/client_server/voip.rs b/src/api/client_server/voip.rs index 7e9de31..4990c17 100644 --- a/src/client_server/voip.rs +++ b/src/api/client_server/voip.rs @@ -1,5 +1,5 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; -use hmac::{Hmac, Mac, NewMac}; +use crate::{services, Result, Ruma}; +use hmac::{Hmac, Mac}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch}; use sha1::Sha1; use std::time::{Duration, SystemTime}; @@ -10,16 +10,15 @@ type HmacSha1 = Hmac<Sha1>; /// /// TODO: Returns information about the recommended turn server. pub async fn turn_server_route( - db: DatabaseGuard, - body: Ruma<get_turn_server_info::v3::IncomingRequest>, + body: Ruma<get_turn_server_info::v3::Request>, ) -> Result<get_turn_server_info::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let turn_secret = db.globals.turn_secret(); + let turn_secret = services().globals.turn_secret().clone(); let (username, password) = if !turn_secret.is_empty() { let expiry = SecondsSinceUnixEpoch::from_system_time( - SystemTime::now() + Duration::from_secs(db.globals.turn_ttl()), + SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()), ) .expect("time is valid"); @@ -34,15 +33,15 @@ pub async fn turn_server_route( (username, password) } else { ( - db.globals.turn_username().clone(), - db.globals.turn_password().clone(), + services().globals.turn_username().clone(), + services().globals.turn_password().clone(), ) }; Ok(get_turn_server_info::v3::Response { username, password, - uris: db.globals.turn_uris().to_vec(), - ttl: Duration::from_secs(db.globals.turn_ttl()), + uris: services().globals.turn_uris().to_vec(), + ttl: Duration::from_secs(services().globals.turn_ttl()), }) } diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..0d2cd66 --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,4 @@ +pub mod appservice_server; +pub mod client_server; +pub mod ruma_wrapper; +pub mod server_server; diff --git a/src/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 45e9d9a..ed28f9d 100644 --- a/src/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -17,14 +17,13 @@ use bytes::{BufMut, Bytes, BytesMut}; use http::StatusCode; use ruma::{ api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, - signatures::CanonicalJsonValue, - DeviceId, ServerName, UserId, + CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId, }; use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{database::DatabaseGuard, server_server, Error, Result}; +use crate::{services, Error, Result}; #[async_trait] impl<T, B> FromRequest<B> for Ruma<T> @@ -44,7 +43,6 @@ where } let metadata = T::METADATA; - let db = DatabaseGuard::from_request(req).await?; let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?; let path_params = Path::<Vec<String>>::from_request(req).await?; @@ -71,7 +69,7 @@ where let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); - let appservices = db.appservice.all().unwrap(); + let appservices = services().appservice.all().unwrap(); let appservice_registration = appservices.iter().find(|(_id, registration)| { registration .get("as_token") @@ -82,7 +80,7 @@ where let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) = appservice_registration { match metadata.authentication { - AuthScheme::AccessToken | AuthScheme::QueryOnlyAccessToken => { + AuthScheme::AccessToken => { let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( @@ -91,14 +89,14 @@ where .unwrap() .as_str() .unwrap(), - db.globals.server_name(), + services().globals.server_name(), ) .unwrap() }, |s| UserId::parse(s).unwrap(), ); - if !db.users.exists(&user_id).unwrap() { + if !services().users.exists(&user_id).unwrap() { return Err(Error::BadRequest( ErrorKind::Forbidden, "User does not exist.", @@ -113,7 +111,7 @@ where } } else { match metadata.authentication { - AuthScheme::AccessToken | AuthScheme::QueryOnlyAccessToken => { + AuthScheme::AccessToken => { let token = match token { Some(token) => token, _ => { @@ -124,7 +122,7 @@ where } }; - match db.users.find_from_token(token).unwrap() { + match services().users.find_from_token(token).unwrap() { None => { return Err(Error::BadRequest( ErrorKind::UnknownToken { soft_logout: false }, @@ -133,7 +131,7 @@ where } Some((user_id, device_id)) => ( Some(user_id), - Some(Box::<DeviceId>::from(device_id)), + Some(OwnedDeviceId::from(device_id)), None, false, ), @@ -185,7 +183,7 @@ where ( "destination".to_owned(), CanonicalJsonValue::String( - db.globals.server_name().as_str().to_owned(), + services().globals.server_name().as_str().to_owned(), ), ), ( @@ -198,12 +196,11 @@ where request_map.insert("content".to_owned(), json_body.clone()); }; - let keys_result = server_server::fetch_signing_keys( - &db, - &x_matrix.origin, - vec![x_matrix.key.to_owned()], - ) - .await; + let keys_result = services() + .rooms + .event_handler + .fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_owned()]) + .await; let keys = match keys_result { Ok(b) => b, @@ -251,7 +248,7 @@ where if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { - UserId::parse_with_server_name("", db.globals.server_name()) + UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid") }); @@ -261,7 +258,7 @@ where .and_then(|auth| auth.get("session")) .and_then(|session| session.as_str()) .and_then(|session| { - db.uiaa.get_uiaa_request( + services().uiaa.get_uiaa_request( &user_id, &sender_device.clone().unwrap_or_else(|| "".into()), session, @@ -284,7 +281,7 @@ where debug!("{:?}", http_request); let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { - warn!("{:?}", e); + warn!("{:?}\n{:?}", e, json_body); Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") })?; @@ -300,7 +297,7 @@ where } struct XMatrix { - origin: Box<ServerName>, + origin: OwnedServerName, key: String, // KeyName? sig: String, } @@ -311,8 +308,7 @@ impl Credentials for XMatrix { fn decode(value: &http::HeaderValue) -> Option<Self> { debug_assert!( value.as_bytes().starts_with(b"X-Matrix "), - "HeaderValue to decode should start with \"X-Matrix ..\", received = {:?}", - value, + "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", ); let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) diff --git a/src/ruma_wrapper.rs b/src/api/ruma_wrapper/mod.rs index 15360e5..ac4c825 100644 --- a/src/ruma_wrapper.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -1,6 +1,7 @@ use crate::Error; use ruma::{ - api::client::uiaa::UiaaResponse, signatures::CanonicalJsonValue, DeviceId, ServerName, UserId, + api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, + OwnedUserId, }; use std::ops::Deref; @@ -10,9 +11,9 @@ mod axum; /// Extractor for Ruma request structs pub struct Ruma<T> { pub body: T, - pub sender_user: Option<Box<UserId>>, - pub sender_device: Option<Box<DeviceId>>, - pub sender_servername: Option<Box<ServerName>>, + pub sender_user: Option<OwnedUserId>, + pub sender_device: Option<OwnedDeviceId>, + pub sender_servername: Option<OwnedServerName>, // This is None when body is not a valid string pub json_body: Option<CanonicalJsonValue>, pub from_appservice: bool, diff --git a/src/api/server_server.rs b/src/api/server_server.rs new file mode 100644 index 0000000..fc3e2c0 --- /dev/null +++ b/src/api/server_server.rs @@ -0,0 +1,1823 @@ +use crate::{ + api::client_server::{self, claim_keys_helper, get_keys_helper}, + service::pdu::{gen_event_id_canonical_json, PduBuilder}, + services, utils, Error, PduEvent, Result, Ruma, +}; +use axum::{response::IntoResponse, Json}; +use get_profile_information::v1::ProfileField; +use http::header::{HeaderValue, AUTHORIZATION}; + +use ruma::{ + api::{ + client::error::{Error as RumaError, ErrorKind}, + federation::{ + authorization::get_event_authorization, + device::get_devices::{self, v1::UserDevice}, + directory::{get_public_rooms, get_public_rooms_filtered}, + discovery::{get_server_keys, get_server_version, ServerSigningKeys, VerifyKey}, + event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, + keys::{claim_keys, get_keys}, + membership::{ + create_invite, + create_join_event::{self, RoomState}, + prepare_join_event, + }, + query::{get_profile_information, get_room_information}, + transactions::{ + edu::{DeviceListUpdateContent, DirectDeviceContent, Edu, SigningKeyUpdateContent}, + send_transaction_message, + }, + }, + EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, OutgoingResponse, + SendAccessToken, + }, + directory::{Filter, RoomNetwork}, + events::{ + receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType}, + room::{ + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + }, + RoomEventType, StateEventType, + }, + serde::{Base64, JsonObject, Raw}, + to_device::DeviceIdOrAllDevices, + CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, + OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, ServerName, +}; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use std::{ + collections::BTreeMap, + fmt::Debug, + mem, + net::{IpAddr, SocketAddr}, + sync::{Arc, RwLock}, + time::{Duration, Instant, SystemTime}, +}; + +use tracing::{debug, error, info, warn}; + +/// Wraps either an literal IP address plus port, or a hostname plus complement +/// (colon-plus-port if it was specified). +/// +/// Note: A `FedDest::Named` might contain an IP address in string form if there +/// was no port specified to construct a SocketAddr with. +/// +/// # Examples: +/// ```rust +/// # use conduit::api::server_server::FedDest; +/// # fn main() -> Result<(), std::net::AddrParseError> { +/// FedDest::Literal("198.51.100.3:8448".parse()?); +/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); +/// FedDest::Named("matrix.example.org".to_owned(), "".to_owned()); +/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); +/// FedDest::Named("198.51.100.5".to_owned(), "".to_owned()); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FedDest { + Literal(SocketAddr), + Named(String, String), +} + +impl FedDest { + fn into_https_string(self) -> String { + match self { + Self::Literal(addr) => format!("https://{addr}"), + Self::Named(host, port) => format!("https://{host}{port}"), + } + } + + fn into_uri_string(self) -> String { + match self { + Self::Literal(addr) => addr.to_string(), + Self::Named(host, ref port) => host + port, + } + } + + fn hostname(&self) -> String { + match &self { + Self::Literal(addr) => addr.ip().to_string(), + Self::Named(host, _) => host.clone(), + } + } + + fn port(&self) -> Option<u16> { + match &self { + Self::Literal(addr) => Some(addr.port()), + Self::Named(_, port) => port[1..].parse().ok(), + } + } +} + +#[tracing::instrument(skip(request))] +pub(crate) async fn send_request<T: OutgoingRequest>( + destination: &ServerName, + request: T, +) -> Result<T::IncomingResponse> +where + T: Debug, +{ + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let mut write_destination_to_cache = false; + + let cached_result = services() + .globals + .actual_destination_cache + .read() + .unwrap() + .get(destination) + .cloned(); + + let (actual_destination, host) = if let Some(result) = cached_result { + result + } else { + write_destination_to_cache = true; + + let result = find_actual_destination(destination).await; + + (result.0, result.1.into_uri_string()) + }; + + let actual_destination_str = actual_destination.clone().into_https_string(); + + let mut http_request = request + .try_into_http_request::<Vec<u8>>( + &actual_destination_str, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!( + "Failed to find destination {}: {}", + actual_destination_str, e + ); + Error::BadServerResponse("Invalid destination") + })?; + + let mut request_map = serde_json::Map::new(); + + if !http_request.body().is_empty() { + request_map.insert( + "content".to_owned(), + serde_json::from_slice(http_request.body()) + .expect("body is valid json, we just created it"), + ); + }; + + request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); + request_map.insert( + "uri".to_owned(), + http_request + .uri() + .path_and_query() + .expect("all requests have a path") + .to_string() + .into(), + ); + request_map.insert( + "origin".to_owned(), + services().globals.server_name().as_str().into(), + ); + request_map.insert("destination".to_owned(), destination.as_str().into()); + + let mut request_json = + serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); + + ruma::signatures::sign_json( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut request_json, + ) + .expect("our request json is what ruma expects"); + + let request_json: serde_json::Map<String, serde_json::Value> = + serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap(); + + let signatures = request_json["signatures"] + .as_object() + .unwrap() + .values() + .map(|v| { + v.as_object() + .unwrap() + .iter() + .map(|(k, v)| (k, v.as_str().unwrap())) + }); + + for signature_server in signatures { + for s in signature_server { + http_request.headers_mut().insert( + AUTHORIZATION, + HeaderValue::from_str(&format!( + "X-Matrix origin={},key=\"{}\",sig=\"{}\"", + services().globals.server_name(), + s.0, + s.1 + )) + .unwrap(), + ); + } + } + + let reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); + + let url = reqwest_request.url().clone(); + + let response = services() + .globals + .federation_client() + .execute(reqwest_request) + .await; + + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if status != 200 { + warn!( + "{} {}: {}", + url, + status, + String::from_utf8_lossy(&body) + .lines() + .collect::<Vec<_>>() + .join(" ") + ); + } + + let http_response = http_response_builder + .body(body) + .expect("reqwest body is valid http body"); + + if status == 200 { + let response = T::IncomingResponse::try_from_http_response(http_response); + if response.is_ok() && write_destination_to_cache { + services() + .globals + .actual_destination_cache + .write() + .unwrap() + .insert( + OwnedServerName::from(destination), + (actual_destination, host), + ); + } + + response.map_err(|e| { + warn!( + "Invalid 200 response from {} on: {} {}", + &destination, url, e + ); + Error::BadServerResponse("Server returned bad 200 response.") + }) + } else { + Err(Error::FederationError( + destination.to_owned(), + RumaError::from_http_response(http_response), + )) + } + } + Err(e) => { + warn!( + "Could not send request to {} at {}: {}", + destination, actual_destination_str, e + ); + Err(e.into()) + } + } +} + +fn get_ip_with_port(destination_str: &str) -> Option<FedDest> { + if let Ok(destination) = destination_str.parse::<SocketAddr>() { + Some(FedDest::Literal(destination)) + } else if let Ok(ip_addr) = destination_str.parse::<IpAddr>() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } +} + +fn add_port_to_hostname(destination_str: &str) -> FedDest { + let (host, port) = match destination_str.find(':') { + None => (destination_str, ":8448"), + Some(pos) => destination_str.split_at(pos), + }; + FedDest::Named(host.to_owned(), port.to_owned()) +} + +/// Returns: actual_destination, host header +/// Implemented according to the specification at https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names +/// Numbers in comments below refer to bullet points in linked section of specification +async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) { + let destination_str = destination.as_str().to_owned(); + let mut hostname = destination_str.clone(); + let actual_destination = match get_ip_with_port(&destination_str) { + Some(host_port) => { + // 1: IP literal with provided or default port + host_port + } + None => { + if let Some(pos) = destination_str.find(':') { + // 2: Hostname with included port + let (host, port) = destination_str.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + match request_well_known(destination.as_str()).await { + // 3: A .well-known file is available + Some(delegated_hostname) => { + hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); + match get_ip_with_port(&delegated_hostname) { + Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file + None => { + if let Some(pos) = delegated_hostname.find(':') { + // 3.2: Hostname with port in .well-known file + let (host, port) = delegated_hostname.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + // Delegated hostname has no port in this branch + if let Some(hostname_override) = + query_srv_record(&delegated_hostname).await + { + // 3.3: SRV lookup successful + let force_port = hostname_override.port(); + + if let Ok(override_ip) = services() + .globals + .dns_resolver() + .lookup_ip(hostname_override.hostname()) + .await + { + services() + .globals + .tls_name_override + .write() + .unwrap() + .insert( + delegated_hostname.clone(), + ( + override_ip.iter().collect(), + force_port.unwrap_or(8448), + ), + ); + } else { + warn!("Using SRV record, but could not resolve to IP"); + } + + if let Some(port) = force_port { + FedDest::Named(delegated_hostname, format!(":{port}")) + } else { + add_port_to_hostname(&delegated_hostname) + } + } else { + // 3.4: No SRV records, just use the hostname from .well-known + add_port_to_hostname(&delegated_hostname) + } + } + } + } + } + // 4: No .well-known or an error occured + None => { + match query_srv_record(&destination_str).await { + // 4: SRV record found + Some(hostname_override) => { + let force_port = hostname_override.port(); + + if let Ok(override_ip) = services() + .globals + .dns_resolver() + .lookup_ip(hostname_override.hostname()) + .await + { + services() + .globals + .tls_name_override + .write() + .unwrap() + .insert( + hostname.clone(), + ( + override_ip.iter().collect(), + force_port.unwrap_or(8448), + ), + ); + } else { + warn!("Using SRV record, but could not resolve to IP"); + } + + if let Some(port) = force_port { + FedDest::Named(hostname.clone(), format!(":{port}")) + } else { + add_port_to_hostname(&hostname) + } + } + // 5: No SRV record found + None => add_port_to_hostname(&destination_str), + } + } + } + } + } + }; + + // Can't use get_ip_with_port here because we don't want to add a port + // to an IP address if it wasn't specified + let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() { + FedDest::Literal(addr) + } else if let Ok(addr) = hostname.parse::<IpAddr>() { + FedDest::Named(addr.to_string(), ":8448".to_owned()) + } else if let Some(pos) = hostname.find(':') { + let (host, port) = hostname.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + FedDest::Named(hostname, ":8448".to_owned()) + }; + (actual_destination, hostname) +} + +async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> { + if let Ok(Some(host_port)) = services() + .globals + .dns_resolver() + .srv_lookup(format!("_matrix._tcp.{hostname}")) + .await + .map(|srv| { + srv.iter().next().map(|result| { + FedDest::Named( + result.target().to_string().trim_end_matches('.').to_owned(), + format!(":{}", result.port()), + ) + }) + }) + { + Some(host_port) + } else { + None + } +} + +async fn request_well_known(destination: &str) -> Option<String> { + let body: serde_json::Value = serde_json::from_str( + &services() + .globals + .default_client() + .get(&format!("https://{destination}/.well-known/matrix/server")) + .send() + .await + .ok()? + .text() + .await + .ok()?, + ) + .ok()?; + Some(body.get("m.server")?.as_str()?.to_owned()) +} + +/// # `GET /_matrix/federation/v1/version` +/// +/// Get version information on this server. +pub async fn get_server_version_route( + _body: Ruma<get_server_version::v1::Request>, +) -> Result<get_server_version::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + Ok(get_server_version::v1::Response { + server: Some(get_server_version::v1::Server { + name: Some("Conduit".to_owned()), + version: Some(env!("CARGO_PKG_VERSION").to_owned()), + }), + }) +} + +/// # `GET /_matrix/key/v2/server` +/// +/// Gets the public signing keys of this server. +/// +/// - Matrix does not support invalidating public keys, so the key returned by this will be valid +/// forever. +// Response type for this endpoint is Json because we need to calculate a signature for the response +pub async fn get_server_keys_route() -> Result<impl IntoResponse> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let mut verify_keys: BTreeMap<OwnedServerSigningKeyId, VerifyKey> = BTreeMap::new(); + verify_keys.insert( + format!("ed25519:{}", services().globals.keypair().version()) + .try_into() + .expect("found invalid server signing keys in DB"), + VerifyKey { + key: Base64::new(services().globals.keypair().public_key().to_vec()), + }, + ); + let mut response = serde_json::from_slice( + get_server_keys::v2::Response { + server_key: Raw::new(&ServerSigningKeys { + server_name: services().globals.server_name().to_owned(), + verify_keys, + old_verify_keys: BTreeMap::new(), + signatures: BTreeMap::new(), + valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time( + SystemTime::now() + Duration::from_secs(86400 * 7), + ) + .expect("time is valid"), + }) + .expect("static conversion, no errors"), + } + .try_into_http_response::<Vec<u8>>() + .unwrap() + .body(), + ) + .unwrap(); + + ruma::signatures::sign_json( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut response, + ) + .unwrap(); + + Ok(Json(response)) +} + +/// # `GET /_matrix/key/v2/server/{keyId}` +/// +/// Gets the public signing keys of this server. +/// +/// - Matrix does not support invalidating public keys, so the key returned by this will be valid +/// forever. +pub async fn get_server_keys_deprecated_route() -> impl IntoResponse { + get_server_keys_route().await +} + +/// # `POST /_matrix/federation/v1/publicRooms` +/// +/// Lists the public rooms on this server. +pub async fn get_public_rooms_filtered_route( + body: Ruma<get_public_rooms_filtered::v1::Request>, +) -> Result<get_public_rooms_filtered::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let response = client_server::get_public_rooms_filtered_helper( + None, + body.limit, + body.since.as_deref(), + &body.filter, + &body.room_network, + ) + .await?; + + Ok(get_public_rooms_filtered::v1::Response { + chunk: response.chunk, + prev_batch: response.prev_batch, + next_batch: response.next_batch, + total_room_count_estimate: response.total_room_count_estimate, + }) +} + +/// # `GET /_matrix/federation/v1/publicRooms` +/// +/// Lists the public rooms on this server. +pub async fn get_public_rooms_route( + body: Ruma<get_public_rooms::v1::Request>, +) -> Result<get_public_rooms::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let response = client_server::get_public_rooms_filtered_helper( + None, + body.limit, + body.since.as_deref(), + &Filter::default(), + &RoomNetwork::Matrix, + ) + .await?; + + Ok(get_public_rooms::v1::Response { + chunk: response.chunk, + prev_batch: response.prev_batch, + next_batch: response.next_batch, + total_room_count_estimate: response.total_room_count_estimate, + }) +} + +/// # `PUT /_matrix/federation/v1/send/{txnId}` +/// +/// Push EDUs and PDUs to this server. +pub async fn send_transaction_message_route( + body: Ruma<send_transaction_message::v1::Request>, +) -> Result<send_transaction_message::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + let mut resolved_map = BTreeMap::new(); + + let pub_key_map = RwLock::new(BTreeMap::new()); + + // This is all the auth_events that have been recursively fetched so they don't have to be + // deserialized over and over again. + // TODO: make this persist across requests but not in a DB Tree (in globals?) + // TODO: This could potentially also be some sort of trie (suffix tree) like structure so + // that once an auth event is known it would know (using indexes maybe) all of the auth + // events that it references. + // let mut auth_cache = EventMap::new(); + + for pdu in &body.pdus { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + + let room_id: OwnedRoomId = match value + .get("room_id") + .and_then(|id| RoomId::parse(id.as_str()?).ok()) + { + Some(id) => id, + None => { + // Event is invalid + continue; + } + }; + + let room_version_id = match services().rooms.state.get_room_version(&room_id) { + Ok(v) => v, + Err(_) => { + continue; + } + }; + + let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + continue; + } + }; + // We do not add the event_id field to the pdu here because of signature and hashes checks + + services() + .rooms + .event_handler + .acl_check(sender_servername, &room_id)?; + + let mutex = Arc::clone( + services() + .globals + .roomid_mutex_federation + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let start_time = Instant::now(); + resolved_map.insert( + event_id.clone(), + services() + .rooms + .event_handler + .handle_incoming_pdu( + sender_servername, + &event_id, + &room_id, + value, + true, + &pub_key_map, + ) + .await + .map(|_| ()), + ); + drop(mutex_lock); + + let elapsed = start_time.elapsed(); + debug!( + "Handling transaction of event {} took {}m{}s", + event_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); + } + + for pdu in &resolved_map { + if let Err(e) = pdu.1 { + if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) { + warn!("Incoming PDU failed {:?}", pdu); + } + } + } + + for edu in body + .edus + .iter() + .filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok()) + { + match edu { + Edu::Presence(_) => {} + Edu::Receipt(receipt) => { + for (room_id, room_updates) in receipt.receipts { + for (user_id, user_updates) in room_updates.read { + if let Some((event_id, _)) = user_updates + .event_ids + .iter() + .filter_map(|id| { + services() + .rooms + .timeline + .get_pdu_count(id) + .ok() + .flatten() + .map(|r| (id, r)) + }) + .max_by_key(|(_, count)| *count) + { + let mut user_receipts = BTreeMap::new(); + user_receipts.insert(user_id.clone(), user_updates.data); + + let mut receipts = BTreeMap::new(); + receipts.insert(ReceiptType::Read, user_receipts); + + let mut receipt_content = BTreeMap::new(); + receipt_content.insert(event_id.to_owned(), receipts); + + let event = ReceiptEvent { + content: ReceiptEventContent(receipt_content), + room_id: room_id.clone(), + }; + services() + .rooms + .edus + .read_receipt + .readreceipt_update(&user_id, &room_id, event)?; + } else { + // TODO fetch missing events + info!("No known event ids in read receipt: {:?}", user_updates); + } + } + } + } + Edu::Typing(typing) => { + if services() + .rooms + .state_cache + .is_joined(&typing.user_id, &typing.room_id)? + { + if typing.typing { + services().rooms.edus.typing.typing_add( + &typing.user_id, + &typing.room_id, + 3000 + utils::millis_since_unix_epoch(), + )?; + } else { + services() + .rooms + .edus + .typing + .typing_remove(&typing.user_id, &typing.room_id)?; + } + } + } + Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { + services().users.mark_device_key_update(&user_id)?; + } + Edu::DirectToDevice(DirectDeviceContent { + sender, + ev_type, + message_id, + messages, + }) => { + // Check if this is a new transaction id + if services() + .transaction_ids + .existing_txnid(&sender, None, &message_id)? + .is_some() + { + continue; + } + + for (target_user_id, map) in &messages { + for (target_device_id_maybe, event) in map { + match target_device_id_maybe { + DeviceIdOrAllDevices::DeviceId(target_device_id) => { + services().users.add_to_device_event( + &sender, + target_user_id, + target_device_id, + &ev_type.to_string(), + event.deserialize_as().map_err(|e| { + warn!("To-Device event is invalid: {event:?} {e}"); + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) + })?, + )? + } + + DeviceIdOrAllDevices::AllDevices => { + for target_device_id in + services().users.all_device_ids(target_user_id) + { + services().users.add_to_device_event( + &sender, + target_user_id, + &target_device_id?, + &ev_type.to_string(), + event.deserialize_as().map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Event is invalid", + ) + })?, + )?; + } + } + } + } + } + + // Save transaction id with empty data + services() + .transaction_ids + .add_txnid(&sender, None, &message_id, &[])?; + } + Edu::SigningKeyUpdate(SigningKeyUpdateContent { + user_id, + master_key, + self_signing_key, + }) => { + if user_id.server_name() != sender_servername { + continue; + } + if let Some(master_key) = master_key { + services().users.add_cross_signing_keys( + &user_id, + &master_key, + &self_signing_key, + &None, + )?; + } + } + Edu::_Custom(_) => {} + } + } + + Ok(send_transaction_message::v1::Response { + pdus: resolved_map + .into_iter() + .map(|(e, r)| (e, r.map_err(|e| e.to_string()))) + .collect(), + }) +} + +/// # `GET /_matrix/federation/v1/event/{eventId}` +/// +/// Retrieves a single event from the server. +/// +/// - Only works if a user of this server is currently invited or joined the room +pub async fn get_event_route( + body: Ruma<get_event::v1::Request>, +) -> Result<get_event::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + let event = services() + .rooms + .timeline + .get_pdu_json(&body.event_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, room_id)? + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Server is not in room", + )); + } + + Ok(get_event::v1::Response { + origin: services().globals.server_name().to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch::now(), + pdu: PduEvent::convert_to_outgoing_federation_event(event), + }) +} + +/// # `POST /_matrix/federation/v1/get_missing_events/{roomId}` +/// +/// Retrieves events that the sender is missing. +pub async fn get_missing_events_route( + body: Ruma<get_missing_events::v1::Request>, +) -> Result<get_missing_events::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Server is not in room", + )); + } + + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; + + let mut queued_events = body.latest_events.clone(); + let mut events = Vec::new(); + + let mut i = 0; + while i < queued_events.len() && events.len() < u64::from(body.limit) as usize { + if let Some(pdu) = services().rooms.timeline.get_pdu_json(&queued_events[i])? { + let room_id_str = pdu + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let event_room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + + if event_room_id != body.room_id { + warn!( + "Evil event detected: Event {} found while searching in room {}", + queued_events[i], body.room_id + ); + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Evil event detected", + )); + } + + if body.earliest_events.contains(&queued_events[i]) { + i += 1; + continue; + } + queued_events.extend_from_slice( + &serde_json::from_value::<Vec<OwnedEventId>>( + serde_json::to_value(pdu.get("prev_events").cloned().ok_or_else(|| { + Error::bad_database("Event in db has no prev_events field.") + })?) + .expect("canonical json is valid json value"), + ) + .map_err(|_| Error::bad_database("Invalid prev_events content in pdu in db."))?, + ); + events.push(PduEvent::convert_to_outgoing_federation_event(pdu)); + } + i += 1; + } + + Ok(get_missing_events::v1::Response { events }) +} + +/// # `GET /_matrix/federation/v1/event_auth/{roomId}/{eventId}` +/// +/// Retrieves the auth chain for a given event. +/// +/// - This does not include the event itself +pub async fn get_event_authorization_route( + body: Ruma<get_event_authorization::v1::Request>, +) -> Result<get_event_authorization::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Server is not in room.", + )); + } + + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; + + let event = services() + .rooms + .timeline + .get_pdu_json(&body.event_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; + + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]) + .await?; + + Ok(get_event_authorization::v1::Response { + auth_chain: auth_chain_ids + .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok()?) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + }) +} + +/// # `GET /_matrix/federation/v1/state/{roomId}` +/// +/// Retrieves the current state of the room. +pub async fn get_room_state_route( + body: Ruma<get_room_state::v1::Request>, +) -> Result<get_room_state::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Server is not in room.", + )); + } + + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; + + let shortstatehash = services() + .rooms + .state_accessor + .pdu_shortstatehash(&body.event_id)? + .ok_or(Error::BadRequest( + ErrorKind::NotFound, + "Pdu state not found.", + ))?; + + let pdus = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await? + .into_values() + .map(|id| { + PduEvent::convert_to_outgoing_federation_event( + services() + .rooms + .timeline + .get_pdu_json(&id) + .unwrap() + .unwrap(), + ) + }) + .collect(); + + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]) + .await?; + + Ok(get_room_state::v1::Response { + auth_chain: auth_chain_ids + .filter_map( + |id| match services().rooms.timeline.get_pdu_json(&id).ok()? { + Some(json) => Some(PduEvent::convert_to_outgoing_federation_event(json)), + None => { + error!("Could not find event json for {id} in db."); + None + } + }, + ) + .collect(), + pdus, + }) +} + +/// # `GET /_matrix/federation/v1/state_ids/{roomId}` +/// +/// Retrieves the current state of the room. +pub async fn get_room_state_ids_route( + body: Ruma<get_room_state_ids::v1::Request>, +) -> Result<get_room_state_ids::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Server is not in room.", + )); + } + + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; + + let shortstatehash = services() + .rooms + .state_accessor + .pdu_shortstatehash(&body.event_id)? + .ok_or(Error::BadRequest( + ErrorKind::NotFound, + "Pdu state not found.", + ))?; + + let pdu_ids = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await? + .into_values() + .map(|id| (*id).to_owned()) + .collect(); + + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]) + .await?; + + Ok(get_room_state_ids::v1::Response { + auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), + pdu_ids, + }) +} + +/// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}` +/// +/// Creates a join template. +pub async fn create_join_event_template_route( + body: Ruma<prepare_join_event::v1::Request>, +) -> Result<prepare_join_event::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + if !services().rooms.metadata.exists(&body.room_id)? { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Room is unknown to this server.", + )); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(body.room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // TODO: Conduit does not implement restricted join rules yet, we always reject + let join_rules_event = services().rooms.state_accessor.room_state_get( + &body.room_id, + &StateEventType::RoomJoinRules, + "", + )?; + + let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose()?; + + if let Some(join_rules_event_content) = join_rules_event_content { + if matches!( + join_rules_event_content.join_rule, + JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. } + ) { + return Err(Error::BadRequest( + ErrorKind::UnableToAuthorizeJoin, + "Conduit does not support restricted rooms yet.", + )); + } + } + + let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; + if !body.ver.contains(&room_version_id) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: room_version_id, + }, + "Room version not supported.", + )); + } + + let content = to_raw_value(&RoomMemberEventContent { + avatar_url: None, + blurhash: None, + displayname: None, + is_direct: None, + membership: MembershipState::Join, + third_party_invite: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("member event is valid value"); + + let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( + PduBuilder { + event_type: RoomEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + )?; + + drop(state_lock); + + pdu_json.remove("event_id"); + + Ok(prepare_join_event::v1::Response { + room_version: Some(room_version_id), + event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), + }) +} + +async fn create_join_event( + sender_servername: &ServerName, + room_id: &RoomId, + pdu: &RawJsonValue, +) -> Result<RoomState> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + if !services().rooms.metadata.exists(room_id)? { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Room is unknown to this server.", + )); + } + + services() + .rooms + .event_handler + .acl_check(sender_servername, room_id)?; + + // TODO: Conduit does not implement restricted join rules yet, we always reject + let join_rules_event = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomJoinRules, + "", + )?; + + let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose()?; + + if let Some(join_rules_event_content) = join_rules_event_content { + if matches!( + join_rules_event_content.join_rule, + JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. } + ) { + return Err(Error::BadRequest( + ErrorKind::UnableToAuthorizeJoin, + "Conduit does not support restricted rooms yet.", + )); + } + } + + // We need to return the state prior to joining, let's keep a reference to that here + let shortstatehash = services() + .rooms + .state + .get_room_shortstatehash(room_id)? + .ok_or(Error::BadRequest( + ErrorKind::NotFound, + "Pdu state not found.", + ))?; + + let pub_key_map = RwLock::new(BTreeMap::new()); + // let mut auth_cache = EventMap::new(); + + // We do not add the event_id field to the pdu here because of signature and hashes checks + let room_version_id = services().rooms.state.get_room_version(room_id)?; + let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + } + }; + + let origin: OwnedServerName = serde_json::from_value( + serde_json::to_value(value.get("origin").ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Event needs an origin field.", + ))?) + .expect("CanonicalJson is valid json value"), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; + + let mutex = Arc::clone( + services() + .globals + .roomid_mutex_federation + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let pdu_id: Vec<u8> = services() + .rooms + .event_handler + .handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) + .await? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; + drop(mutex_lock); + + let state_ids = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await?; + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(room_id, state_ids.values().cloned().collect()) + .await?; + + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(|r| r.ok()) + .filter(|server| &**server != services().globals.server_name()); + + services().sending.send_pdu(servers, &pdu_id)?; + + Ok(RoomState { + auth_chain: auth_chain_ids + .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + state: state_ids + .iter() + .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + event: None, // TODO: handle restricted joins + }) +} + +/// # `PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}` +/// +/// Submits a signed join event. +pub async fn create_join_event_v1_route( + body: Ruma<create_join_event::v1::Request>, +) -> Result<create_join_event::v1::Response> { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + + Ok(create_join_event::v1::Response { room_state }) +} + +/// # `PUT /_matrix/federation/v2/send_join/{roomId}/{eventId}` +/// +/// Submits a signed join event. +pub async fn create_join_event_v2_route( + body: Ruma<create_join_event::v2::Request>, +) -> Result<create_join_event::v2::Response> { + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + + Ok(create_join_event::v2::Response { room_state }) +} + +/// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}` +/// +/// Invites a remote user to a room. +pub async fn create_invite_route( + body: Ruma<create_invite::v2::Request>, +) -> Result<create_invite::v2::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; + + if !services() + .globals + .supported_room_versions() + .contains(&body.room_version) + { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: body.room_version.clone(), + }, + "Server does not support this room version.", + )); + } + + let mut signed_event = utils::to_canonical_object(&body.event) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; + + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut signed_event, + &body.room_version, + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; + + // Generate event id + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&signed_event, &body.room_version) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + // Add event_id back + signed_event.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.to_string()), + ); + + let sender: OwnedUserId = serde_json::from_value( + signed_event + .get("sender") + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Event had no sender field.", + ))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user id."))?; + + let invited_user: Box<_> = serde_json::from_value( + signed_event + .get("state_key") + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Event had no state_key field.", + ))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?; + + let mut invite_state = body.invite_room_state.clone(); + + let mut event: JsonObject = serde_json::from_str(body.event.get()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; + + event.insert("event_id".to_owned(), "$dummy".into()); + + let pdu: PduEvent = serde_json::from_value(event.into()).map_err(|e| { + warn!("Invalid invite event: {}", e); + Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.") + })?; + + invite_state.push(pdu.to_stripped_state_event()); + + // If we are active in the room, the remote server will notify us about the join via /send + if !services() + .rooms + .state_cache + .server_in_room(services().globals.server_name(), &body.room_id)? + { + services().rooms.state_cache.update_membership( + &body.room_id, + &invited_user, + MembershipState::Invite, + &sender, + Some(invite_state), + true, + )?; + } + + Ok(create_invite::v2::Response { + event: PduEvent::convert_to_outgoing_federation_event(signed_event), + }) +} + +/// # `GET /_matrix/federation/v1/user/devices/{userId}` +/// +/// Gets information on all devices of the user. +pub async fn get_devices_route( + body: Ruma<get_devices::v1::Request>, +) -> Result<get_devices::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + Ok(get_devices::v1::Response { + user_id: body.user_id.clone(), + stream_id: services() + .users + .get_devicelist_version(&body.user_id)? + .unwrap_or(0) + .try_into() + .expect("version will not grow that large"), + devices: services() + .users + .all_devices_metadata(&body.user_id) + .filter_map(|r| r.ok()) + .filter_map(|metadata| { + Some(UserDevice { + keys: services() + .users + .get_device_keys(&body.user_id, &metadata.device_id) + .ok()??, + device_id: metadata.device_id, + device_display_name: metadata.display_name, + }) + }) + .collect(), + master_key: services() + .users + .get_master_key(&body.user_id, &|u| u.server_name() == sender_servername)?, + self_signing_key: services() + .users + .get_self_signing_key(&body.user_id, &|u| u.server_name() == sender_servername)?, + }) +} + +/// # `GET /_matrix/federation/v1/query/directory` +/// +/// Resolve a room alias to a room id. +pub async fn get_room_information_route( + body: Ruma<get_room_information::v1::Request>, +) -> Result<get_room_information::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let room_id = services() + .rooms + .alias + .resolve_local_alias(&body.room_alias)? + .ok_or(Error::BadRequest( + ErrorKind::NotFound, + "Room alias not found.", + ))?; + + Ok(get_room_information::v1::Response { + room_id, + servers: vec![services().globals.server_name().to_owned()], + }) +} + +/// # `GET /_matrix/federation/v1/query/profile` +/// +/// Gets information on a profile. +pub async fn get_profile_information_route( + body: Ruma<get_profile_information::v1::Request>, +) -> Result<get_profile_information::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let mut displayname = None; + let mut avatar_url = None; + let mut blurhash = None; + + match &body.field { + Some(ProfileField::DisplayName) => { + displayname = services().users.displayname(&body.user_id)? + } + Some(ProfileField::AvatarUrl) => { + avatar_url = services().users.avatar_url(&body.user_id)?; + blurhash = services().users.blurhash(&body.user_id)? + } + // TODO: what to do with custom + Some(_) => {} + None => { + displayname = services().users.displayname(&body.user_id)?; + avatar_url = services().users.avatar_url(&body.user_id)?; + blurhash = services().users.blurhash(&body.user_id)?; + } + } + + Ok(get_profile_information::v1::Response { + blurhash, + displayname, + avatar_url, + }) +} + +/// # `POST /_matrix/federation/v1/user/keys/query` +/// +/// Gets devices and identity keys for the given users. +pub async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result<get_keys::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let result = get_keys_helper(None, &body.device_keys, |u| { + Some(u.server_name()) == body.sender_servername.as_deref() + }) + .await?; + + Ok(get_keys::v1::Response { + device_keys: result.device_keys, + master_keys: result.master_keys, + self_signing_keys: result.self_signing_keys, + }) +} + +/// # `POST /_matrix/federation/v1/user/keys/claim` +/// +/// Claims one-time keys. +pub async fn claim_keys_route( + body: Ruma<claim_keys::v1::Request>, +) -> Result<claim_keys::v1::Response> { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } + + let result = claim_keys_helper(&body.one_time_keys).await?; + + Ok(claim_keys::v1::Response { + one_time_keys: result.one_time_keys, + }) +} + +#[cfg(test)] +mod tests { + use super::{add_port_to_hostname, get_ip_with_port, FedDest}; + + #[test] + fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1"), + Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("dead:beef::"), + Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) + ); + } + + #[test] + fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1:1234"), + Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("[dead::beef]:8933"), + Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) + ); + } + + #[test] + fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname("example.com"), + FedDest::Named(String::from("example.com"), String::from(":8448")) + ) + } + + #[test] + fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname("example.com:1337"), + FedDest::Named(String::from("example.com"), String::from(":1337")) + ) + } +} diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs deleted file mode 100644 index a1b616b..0000000 --- a/src/client_server/membership.rs +++ /dev/null @@ -1,1082 +0,0 @@ -use crate::{ - client_server, - database::DatabaseGuard, - pdu::{EventHash, PduBuilder, PduEvent}, - server_server, utils, Database, Error, Result, Ruma, -}; -use ruma::{ - api::{ - client::{ - error::ErrorKind, - membership::{ - ban_user, forget_room, get_member_events, invite_user, join_room_by_id, - join_room_by_id_or_alias, joined_members, joined_rooms, kick_user, leave_room, - unban_user, IncomingThirdPartySigned, - }, - }, - federation::{self, membership::create_invite}, - }, - events::{ - room::{ - create::RoomCreateEventContent, - member::{MembershipState, RoomMemberEventContent}, - }, - RoomEventType, StateEventType, - }, - serde::{to_canonical_value, Base64, CanonicalJsonObject, CanonicalJsonValue}, - state_res::{self, RoomVersion}, - uint, EventId, RoomId, RoomVersionId, ServerName, UserId, -}; -use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap}, - iter, - sync::{Arc, RwLock}, - time::{Duration, Instant}, -}; -use tracing::{debug, error, warn}; - -/// # `POST /_matrix/client/r0/rooms/{roomId}/join` -/// -/// Tries to join the sender user into a room. -/// -/// - If the server knowns about this room: creates the join event and does auth rules locally -/// - If the server does not know about the room: asks other servers over federation -pub async fn join_room_by_id_route( - db: DatabaseGuard, - body: Ruma<join_room_by_id::v3::IncomingRequest>, -) -> Result<join_room_by_id::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let mut servers = Vec::new(); // There is no body.server_name for /roomId/join - servers.extend( - db.rooms - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); - - servers.push(body.room_id.server_name().to_owned()); - - let ret = join_room_by_id_helper( - &db, - body.sender_user.as_deref(), - &body.room_id, - &servers, - body.third_party_signed.as_ref(), - ) - .await; - - db.flush()?; - - ret -} - -/// # `POST /_matrix/client/r0/join/{roomIdOrAlias}` -/// -/// Tries to join the sender user into a room. -/// -/// - If the server knowns about this room: creates the join event and does auth rules locally -/// - If the server does not know about the room: asks other servers over federation -pub async fn join_room_by_id_or_alias_route( - db: DatabaseGuard, - body: Ruma<join_room_by_id_or_alias::v3::IncomingRequest>, -) -> Result<join_room_by_id_or_alias::v3::Response> { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let body = body.body; - - let (servers, room_id) = match Box::<RoomId>::try_from(body.room_id_or_alias) { - Ok(room_id) => { - let mut servers = body.server_name.clone(); - servers.extend( - db.rooms - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); - - servers.push(room_id.server_name().to_owned()); - (servers, room_id) - } - Err(room_alias) => { - let response = client_server::get_alias_helper(&db, &room_alias).await?; - - (response.servers.into_iter().collect(), response.room_id) - } - }; - - let join_room_response = join_room_by_id_helper( - &db, - Some(sender_user), - &room_id, - &servers, - body.third_party_signed.as_ref(), - ) - .await?; - - db.flush()?; - - Ok(join_room_by_id_or_alias::v3::Response { - room_id: join_room_response.room_id, - }) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/leave` -/// -/// Tries to leave the sender user from a room. -/// -/// - This should always work if the user is currently joined. -pub async fn leave_room_route( - db: DatabaseGuard, - body: Ruma<leave_room::v3::IncomingRequest>, -) -> Result<leave_room::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - db.rooms.leave_room(sender_user, &body.room_id, &db).await?; - - db.flush()?; - - Ok(leave_room::v3::Response::new()) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/invite` -/// -/// Tries to send an invite event into the room. -pub async fn invite_user_route( - db: DatabaseGuard, - body: Ruma<invite_user::v3::IncomingRequest>, -) -> Result<invite_user::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if let invite_user::v3::IncomingInvitationRecipient::UserId { user_id } = &body.recipient { - invite_helper(sender_user, user_id, &body.room_id, &db, false).await?; - db.flush()?; - Ok(invite_user::v3::Response {}) - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) - } -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/kick` -/// -/// Tries to send a kick event into the room. -pub async fn kick_user_route( - db: DatabaseGuard, - body: Ruma<kick_user::v3::IncomingRequest>, -) -> Result<kick_user::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let mut event: RoomMemberEventContent = serde_json::from_str( - db.rooms - .room_state_get( - &body.room_id, - &StateEventType::RoomMember, - &body.user_id.to_string(), - )? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot kick member that's not in the room.", - ))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - event.membership = MembershipState::Leave; - // TODO: reason - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - sender_user, - &body.room_id, - &db, - &state_lock, - )?; - - drop(state_lock); - - db.flush()?; - - Ok(kick_user::v3::Response::new()) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/ban` -/// -/// Tries to send a ban event into the room. -pub async fn ban_user_route( - db: DatabaseGuard, - body: Ruma<ban_user::v3::IncomingRequest>, -) -> Result<ban_user::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - // TODO: reason - - let event = db - .rooms - .room_state_get( - &body.room_id, - &StateEventType::RoomMember, - &body.user_id.to_string(), - )? - .map_or( - Ok(RoomMemberEventContent { - membership: MembershipState::Ban, - displayname: db.users.displayname(&body.user_id)?, - avatar_url: db.users.avatar_url(&body.user_id)?, - is_direct: None, - third_party_invite: None, - blurhash: db.users.blurhash(&body.user_id)?, - reason: None, - join_authorized_via_users_server: None, - }), - |event| { - serde_json::from_str(event.content.get()) - .map(|event: RoomMemberEventContent| RoomMemberEventContent { - membership: MembershipState::Ban, - ..event - }) - .map_err(|_| Error::bad_database("Invalid member event in database.")) - }, - )?; - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - sender_user, - &body.room_id, - &db, - &state_lock, - )?; - - drop(state_lock); - - db.flush()?; - - Ok(ban_user::v3::Response::new()) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/unban` -/// -/// Tries to send an unban event into the room. -pub async fn unban_user_route( - db: DatabaseGuard, - body: Ruma<unban_user::v3::IncomingRequest>, -) -> Result<unban_user::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let mut event: RoomMemberEventContent = serde_json::from_str( - db.rooms - .room_state_get( - &body.room_id, - &StateEventType::RoomMember, - &body.user_id.to_string(), - )? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot unban a user who is not banned.", - ))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - event.membership = MembershipState::Leave; - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - sender_user, - &body.room_id, - &db, - &state_lock, - )?; - - drop(state_lock); - - db.flush()?; - - Ok(unban_user::v3::Response::new()) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/forget` -/// -/// Forgets about a room. -/// -/// - If the sender user currently left the room: Stops sender user from receiving information about the room -/// -/// Note: Other devices of the user have no way of knowing the room was forgotten, so this has to -/// be called from every device -pub async fn forget_room_route( - db: DatabaseGuard, - body: Ruma<forget_room::v3::IncomingRequest>, -) -> Result<forget_room::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - db.rooms.forget(&body.room_id, sender_user)?; - - db.flush()?; - - Ok(forget_room::v3::Response::new()) -} - -/// # `POST /_matrix/client/r0/joined_rooms` -/// -/// Lists all rooms the user has joined. -pub async fn joined_rooms_route( - db: DatabaseGuard, - body: Ruma<joined_rooms::v3::Request>, -) -> Result<joined_rooms::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - Ok(joined_rooms::v3::Response { - joined_rooms: db - .rooms - .rooms_joined(sender_user) - .filter_map(|r| r.ok()) - .collect(), - }) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/members` -/// -/// Lists all joined users in a room (TODO: at a specific point in time, with a specific membership). -/// -/// - Only works if the user is currently joined -pub async fn get_member_events_route( - db: DatabaseGuard, - body: Ruma<get_member_events::v3::IncomingRequest>, -) -> Result<get_member_events::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - // TODO: check history visibility? - if !db.rooms.is_joined(sender_user, &body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this room.", - )); - } - - Ok(get_member_events::v3::Response { - chunk: db - .rooms - .room_state_full(&body.room_id) - .await? - .iter() - .filter(|(key, _)| key.0 == StateEventType::RoomMember) - .map(|(_, pdu)| pdu.to_member_event().into()) - .collect(), - }) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/joined_members` -/// -/// Lists all members of a room. -/// -/// - The sender user must be in the room -/// - TODO: An appservice just needs a puppet joined -pub async fn joined_members_route( - db: DatabaseGuard, - body: Ruma<joined_members::v3::IncomingRequest>, -) -> Result<joined_members::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - if !db.rooms.is_joined(sender_user, &body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You aren't a member of the room.", - )); - } - - let mut joined = BTreeMap::new(); - for user_id in db.rooms.room_members(&body.room_id).filter_map(|r| r.ok()) { - let display_name = db.users.displayname(&user_id)?; - let avatar_url = db.users.avatar_url(&user_id)?; - - joined.insert( - user_id, - joined_members::v3::RoomMember { - display_name, - avatar_url, - }, - ); - } - - Ok(joined_members::v3::Response { joined }) -} - -#[tracing::instrument(skip(db))] -async fn join_room_by_id_helper( - db: &Database, - sender_user: Option<&UserId>, - room_id: &RoomId, - servers: &[Box<ServerName>], - _third_party_signed: Option<&IncomingThirdPartySigned>, -) -> Result<join_room_by_id::v3::Response> { - let sender_user = sender_user.expect("user is authenticated"); - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Ask a remote server if we don't have this room - if !db.rooms.exists(room_id)? { - let mut make_join_response_and_server = Err(Error::BadServerResponse( - "No server available to assist in joining.", - )); - - for remote_server in servers { - let make_join_response = db - .sending - .send_federation_request( - &db.globals, - remote_server, - federation::membership::prepare_join_event::v1::Request { - room_id, - user_id: sender_user, - ver: &db.globals.supported_room_versions(), - }, - ) - .await; - - make_join_response_and_server = make_join_response.map(|r| (r, remote_server)); - - if make_join_response_and_server.is_ok() { - break; - } - } - - let (make_join_response, remote_server) = make_join_response_and_server?; - - let room_version = match make_join_response.room_version { - Some(room_version) if db.rooms.is_supported_version(&db, &room_version) => room_version, - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; - - // TODO: Is origin needed? - join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(db.globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - ); - - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - join_event_stub.remove("event_id"); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut join_event_stub, - &room_version, - ) - .expect("event is valid, we just created it"); - - // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()) - .expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - join_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - // It has enough fields to be called a proper event now - let join_event = join_event_stub; - - let send_join_response = db - .sending - .send_federation_request( - &db.globals, - remote_server, - federation::membership::create_join_event::v2::Request { - room_id, - event_id, - pdu: &PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - }, - ) - .await?; - - db.rooms.get_or_create_shortroomid(room_id, &db.globals)?; - - let parsed_pdu = PduEvent::from_id_val(event_id, join_event.clone()) - .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; - - let mut state = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); - - server_server::fetch_join_signing_keys( - &send_join_response, - &room_version, - &pub_key_map, - db, - ) - .await?; - - for result in send_join_response - .room_state - .state - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, db)) - { - let (event_id, value) = match result { - Ok(t) => t, - Err(_) => continue, - }; - - let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { - warn!("{:?}: {}", value, e); - Error::BadServerResponse("Invalid PDU in send_join response.") - })?; - - db.rooms.add_pdu_outlier(&event_id, &value)?; - if let Some(state_key) = &pdu.state_key { - let shortstatekey = db.rooms.get_or_create_shortstatekey( - &pdu.kind.to_string().into(), - state_key, - &db.globals, - )?; - state.insert(shortstatekey, pdu.event_id.clone()); - } - } - - let incoming_shortstatekey = db.rooms.get_or_create_shortstatekey( - &parsed_pdu.kind.to_string().into(), - parsed_pdu - .state_key - .as_ref() - .expect("Pdu is a membership state event"), - &db.globals, - )?; - - state.insert(incoming_shortstatekey, parsed_pdu.event_id.clone()); - - let create_shortstatekey = db - .rooms - .get_shortstatekey(&StateEventType::RoomCreate, "")? - .expect("Room exists"); - - if state.get(&create_shortstatekey).is_none() { - return Err(Error::BadServerResponse("State contained no create event.")); - } - - db.rooms.force_state( - room_id, - state - .into_iter() - .map(|(k, id)| db.rooms.compress_state_event(k, &id, &db.globals)) - .collect::<Result<_>>()?, - db, - )?; - - for result in send_join_response - .room_state - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, db)) - { - let (event_id, value) = match result { - Ok(t) => t, - Err(_) => continue, - }; - - db.rooms.add_pdu_outlier(&event_id, &value)?; - } - - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = db.rooms.append_to_state(&parsed_pdu, &db.globals)?; - - db.rooms.append_pdu( - &parsed_pdu, - join_event, - iter::once(&*parsed_pdu.event_id), - db, - )?; - - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - db.rooms.set_room_state(room_id, statehashid)?; - } else { - let event = RoomMemberEventContent { - membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, - reason: None, - join_authorized_via_users_server: None, - }; - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - }, - sender_user, - room_id, - db, - &state_lock, - )?; - } - - drop(state_lock); - - db.flush()?; - - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) -} - -fn validate_and_add_event_id( - pdu: &RawJsonValue, - room_version: &RoomVersionId, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - db: &Database, -) -> Result<(Box<EventId>, CanonicalJsonObject)> { - let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&value, room_version) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { - Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; - - if let Some((time, tries)) = db - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(&event_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); - } - } - - if let Err(e) = ruma::signatures::verify_event( - &*pub_key_map - .read() - .map_err(|_| Error::bad_database("RwLock is poisoned."))?, - &value, - room_version, - ) { - warn!("Event {} failed verification {:?} {}", event_id, pdu, e); - back_off(event_id); - return Err(Error::BadServerResponse("Event failed verification.")); - } - - value.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - Ok((event_id, value)) -} - -pub(crate) async fn invite_helper<'a>( - sender_user: &UserId, - user_id: &UserId, - room_id: &RoomId, - db: &Database, - is_direct: bool, -) -> Result<()> { - if user_id.server_name() != db.globals.server_name() { - let (room_version_id, pdu_json, invite_room_state) = { - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - let prev_events: Vec<_> = db - .rooms - .get_pdu_leaves(room_id)? - .into_iter() - .take(20) - .collect(); - - let create_event = db - .rooms - .room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option<RoomCreateEventContent> = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - - // If there was no create event yet, assume we are creating a room with the default - // version right now - let room_version_id = create_event_content - .map_or(db.globals.default_room_version(), |create_event| { - create_event.room_version - }); - let room_version = - RoomVersion::new(&room_version_id).expect("room version is supported"); - - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: None, - displayname: None, - is_direct: Some(is_direct), - membership: MembershipState::Invite, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("member event is valid value"); - - let state_key = user_id.to_string(); - let kind = StateEventType::RoomMember; - - let auth_events = db.rooms.get_auth_events( - room_id, - &kind.to_string().into(), - sender_user, - Some(&state_key), - &content, - )?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) - + uint!(1); - - let mut unsigned = BTreeMap::new(); - - if let Some(prev_pdu) = db.rooms.room_state_get(room_id, &kind, &state_key)? { - unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); - unsigned.insert( - "prev_sender".to_owned(), - to_raw_value(&prev_pdu.sender).expect("UserId is valid"), - ); - } - - let pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater").into(), - room_id: room_id.to_owned(), - sender: sender_user.to_owned(), - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"), - kind: kind.to_string().into(), - content, - state_key: Some(state_key), - prev_events, - depth, - auth_events: auth_events - .iter() - .map(|(_, pdu)| pdu.event_id.clone()) - .collect(), - redacts: None, - unsigned: if unsigned.is_empty() { - None - } else { - Some(to_raw_value(&unsigned).expect("to_raw_value always works")) - }, - hashes: EventHash { - sha256: "aaa".to_owned(), - }, - signatures: None, - }; - - let auth_check = state_res::auth_check( - &room_version, - &pdu, - None::<PduEvent>, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - .map_err(|e| { - error!("{:?}", e); - Error::bad_database("Auth check failed.") - })?; - - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Event is not authorized.", - )); - } - - // Hash and sign - let mut pdu_json = - utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); - - pdu_json.remove("event_id"); - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - to_canonical_value(db.globals.server_name()) - .expect("server name is a valid CanonicalJsonValue"), - ); - - ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut pdu_json, - &room_version_id, - ) - .expect("event is valid, we just created it"); - - let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; - - drop(state_lock); - - (room_version_id, pdu_json, invite_room_state) - }; - - // Generate event id - let expected_event_id = format!( - "${}", - ruma::signatures::reference_hash(&pdu_json, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let expected_event_id = <&EventId>::try_from(expected_event_id.as_str()) - .expect("ruma's reference hashes are valid event ids"); - - let response = db - .sending - .send_federation_request( - &db.globals, - user_id.server_name(), - create_invite::v2::Request { - room_id, - event_id: expected_event_id, - room_version: &room_version_id, - event: &PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), - invite_room_state: &invite_room_state, - }, - ) - .await?; - - let pub_key_map = RwLock::new(BTreeMap::new()); - - // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(&response.event, &db) - { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; - - if expected_event_id != event_id { - warn!("Server {} changed invite event, that's not allowed in the spec: ours: {:?}, theirs: {:?}", user_id.server_name(), pdu_json, value); - } - - let origin: Box<ServerName> = serde_json::from_value( - serde_json::to_value(value.get("origin").ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event needs an origin field.", - ))?) - .expect("CanonicalJson is valid json value"), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - - let pdu_id = server_server::handle_incoming_pdu( - &origin, - &event_id, - room_id, - value, - true, - db, - &pub_key_map, - ) - .await - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; - - let servers = db - .rooms - .room_servers(room_id) - .filter_map(|r| r.ok()) - .filter(|server| &**server != db.globals.server_name()); - - db.sending.send_pdu(servers, &pdu_id)?; - - return Ok(()); - } - - if !db.rooms.is_joined(sender_user, &room_id)? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this room.", - )); - } - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: db.users.displayname(user_id)?, - avatar_url: db.users.avatar_url(user_id)?, - is_direct: Some(is_direct), - third_party_invite: None, - blurhash: db.users.blurhash(user_id)?, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - sender_user, - room_id, - db, - &state_lock, - )?; - - drop(state_lock); - - Ok(()) -} diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs deleted file mode 100644 index 91988a4..0000000 --- a/src/client_server/read_marker.rs +++ /dev/null @@ -1,127 +0,0 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; -use ruma::{ - api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, - events::RoomAccountDataEventType, - receipt::ReceiptType, - MilliSecondsSinceUnixEpoch, -}; -use std::collections::BTreeMap; - -/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` -/// -/// Sets different types of read markers. -/// -/// - Updates fully-read account data event to `fully_read` -/// - If `read_receipt` is set: Update private marker and public read receipt EDU -pub async fn set_read_marker_route( - db: DatabaseGuard, - body: Ruma<set_read_marker::v3::IncomingRequest>, -) -> Result<set_read_marker::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let fully_read_event = ruma::events::fully_read::FullyReadEvent { - content: ruma::events::fully_read::FullyReadEventContent { - event_id: body.fully_read.clone(), - }, - }; - db.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &fully_read_event, - &db.globals, - )?; - - if let Some(event) = &body.read_receipt { - db.rooms.edus.private_read_set( - &body.room_id, - sender_user, - db.rooms.get_pdu_count(event)?.ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event does not exist.", - ))?, - &db.globals, - )?; - db.rooms - .reset_notification_counts(sender_user, &body.room_id)?; - - let mut user_receipts = BTreeMap::new(); - user_receipts.insert( - sender_user.clone(), - ruma::events::receipt::Receipt { - ts: Some(MilliSecondsSinceUnixEpoch::now()), - }, - ); - - let mut receipts = BTreeMap::new(); - receipts.insert(ReceiptType::Read, user_receipts); - - let mut receipt_content = BTreeMap::new(); - receipt_content.insert(event.to_owned(), receipts); - - db.rooms.edus.readreceipt_update( - sender_user, - &body.room_id, - ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - &db.globals, - )?; - } - - db.flush()?; - - Ok(set_read_marker::v3::Response {}) -} - -/// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}` -/// -/// Sets private read marker and public read receipt EDU. -pub async fn create_receipt_route( - db: DatabaseGuard, - body: Ruma<create_receipt::v3::IncomingRequest>, -) -> Result<create_receipt::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - db.rooms.edus.private_read_set( - &body.room_id, - sender_user, - db.rooms - .get_pdu_count(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event does not exist.", - ))?, - &db.globals, - )?; - db.rooms - .reset_notification_counts(sender_user, &body.room_id)?; - - let mut user_receipts = BTreeMap::new(); - user_receipts.insert( - sender_user.clone(), - ruma::events::receipt::Receipt { - ts: Some(MilliSecondsSinceUnixEpoch::now()), - }, - ); - let mut receipts = BTreeMap::new(); - receipts.insert(ReceiptType::Read, user_receipts); - - let mut receipt_content = BTreeMap::new(); - receipt_content.insert(body.event_id.to_owned(), receipts); - - db.rooms.edus.readreceipt_update( - sender_user, - &body.room_id, - ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - &db.globals, - )?; - - db.flush()?; - - Ok(create_receipt::v3::Response {}) -} diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs deleted file mode 100644 index 98d895c..0000000 --- a/src/client_server/tag.rs +++ /dev/null @@ -1,117 +0,0 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; -use ruma::{ - api::client::tag::{create_tag, delete_tag, get_tags}, - events::{ - tag::{TagEvent, TagEventContent}, - RoomAccountDataEventType, - }, -}; -use std::collections::BTreeMap; - -/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` -/// -/// Adds a tag to the room. -/// -/// - Inserts the tag into the tag event of the room account data. -pub async fn update_tag_route( - db: DatabaseGuard, - body: Ruma<create_tag::v3::IncomingRequest>, -) -> Result<create_tag::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let mut tags_event = db - .account_data - .get( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - )? - .unwrap_or_else(|| TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }); - tags_event - .content - .tags - .insert(body.tag.clone().into(), body.tag_info.clone()); - - db.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &tags_event, - &db.globals, - )?; - - db.flush()?; - - Ok(create_tag::v3::Response {}) -} - -/// # `DELETE /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` -/// -/// Deletes a tag from the room. -/// -/// - Removes the tag from the tag event of the room account data. -pub async fn delete_tag_route( - db: DatabaseGuard, - body: Ruma<delete_tag::v3::IncomingRequest>, -) -> Result<delete_tag::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - let mut tags_event = db - .account_data - .get( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - )? - .unwrap_or_else(|| TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }); - tags_event.content.tags.remove(&body.tag.clone().into()); - - db.account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &tags_event, - &db.globals, - )?; - - db.flush()?; - - Ok(delete_tag::v3::Response {}) -} - -/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags` -/// -/// Returns tags on the room. -/// -/// - Gets the tag event of the room account data. -pub async fn get_tags_route( - db: DatabaseGuard, - body: Ruma<get_tags::v3::IncomingRequest>, -) -> Result<get_tags::v3::Response> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - - Ok(get_tags::v3::Response { - tags: db - .account_data - .get( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - )? - .unwrap_or_else(|| TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - .content - .tags, - }) -} diff --git a/src/config.rs b/src/config/mod.rs index 7d81d0f..31a586f 100644 --- a/src/config.rs +++ b/src/config/mod.rs @@ -4,7 +4,7 @@ use std::{ net::{IpAddr, Ipv4Addr}, }; -use ruma::{RoomVersionId, ServerName}; +use ruma::{OwnedServerName, RoomVersionId}; use serde::{de::IgnoredAny, Deserialize}; use tracing::warn; @@ -20,7 +20,7 @@ pub struct Config { pub port: u16, pub tls: Option<TlsConfig>, - pub server_name: Box<ServerName>, + pub server_name: OwnedServerName, #[serde(default = "default_database_backend")] pub database_backend: String, pub database_path: String, @@ -40,6 +40,8 @@ pub struct Config { pub max_request_size: u32, #[serde(default = "default_max_concurrent_requests")] pub max_concurrent_requests: u16, + #[serde(default = "default_max_fetch_prev_events")] + pub max_fetch_prev_events: u16, #[serde(default = "false_fn")] pub allow_registration: bool, #[serde(default = "true_fn")] @@ -60,7 +62,7 @@ pub struct Config { pub proxy: ProxyConfig, pub jwt_secret: Option<String>, #[serde(default = "Vec::new")] - pub trusted_servers: Vec<Box<ServerName>>, + pub trusted_servers: Vec<OwnedServerName>, #[serde(default = "default_log")] pub log: String, #[serde(default)] @@ -183,7 +185,7 @@ impl fmt::Display for Config { ("Turn TTL", &self.turn_ttl.to_string()), ("Turn URIs", { let mut lst = vec![]; - for item in self.turn_uris.to_vec().into_iter().enumerate() { + for item in self.turn_uris.iter().cloned().enumerate() { let (_, uri): (usize, String) = item; lst.push(uri); } @@ -191,13 +193,13 @@ impl fmt::Display for Config { }), ]; - let mut msg: String = "Active config values:\n\n".to_string(); + let mut msg: String = "Active config values:\n\n".to_owned(); for line in lines.into_iter().enumerate() { msg += &format!("{}: {}\n", line.1 .0, line.1 .1); } - write!(f, "{}", msg) + write!(f, "{msg}") } } @@ -222,7 +224,7 @@ fn default_database_backend() -> String { } fn default_db_cache_capacity_mb() -> f64 { - 10.0 + 1000.0 } fn default_conduit_cache_capacity_modifier() -> f64 { @@ -230,7 +232,7 @@ fn default_conduit_cache_capacity_modifier() -> f64 { } fn default_rocksdb_max_open_files() -> i32 { - 20 + 1000 } fn default_pdu_cache_capacity() -> u32 { @@ -238,7 +240,7 @@ fn default_pdu_cache_capacity() -> u32 { } fn default_cleanup_second_interval() -> u32 { - 1 * 60 // every minute + 60 // every minute } fn default_max_request_size() -> u32 { @@ -249,8 +251,12 @@ fn default_max_concurrent_requests() -> u16 { 100 } +fn default_max_fetch_prev_events() -> u16 { + 100_u16 +} + fn default_log() -> String { - "info,state_res=warn,_=off,sled=off".to_owned() + "warn,state_res=warn,_=off,sled=off".to_owned() } fn default_turn_ttl() -> u64 { @@ -258,6 +264,6 @@ fn default_turn_ttl() -> u64 { } // I know, it's a great name -fn default_default_room_version() -> RoomVersionId { - RoomVersionId::V6 +pub fn default_default_room_version() -> RoomVersionId { + RoomVersionId::V9 } diff --git a/src/database.rs b/src/database.rs deleted file mode 100644 index a0937c2..0000000 --- a/src/database.rs +++ /dev/null @@ -1,1017 +0,0 @@ -pub mod abstraction; - -pub mod account_data; -pub mod admin; -pub mod appservice; -pub mod globals; -pub mod key_backups; -pub mod media; -pub mod pusher; -pub mod rooms; -pub mod sending; -pub mod transaction_ids; -pub mod uiaa; -pub mod users; - -use self::admin::create_admin_room; -use crate::{utils, Config, Error, Result}; -use abstraction::DatabaseEngine; -use directories::ProjectDirs; -use futures_util::{stream::FuturesUnordered, StreamExt}; -use lru_cache::LruCache; -use ruma::{ - events::{ - push_rules::PushRulesEventContent, room::message::RoomMessageEventContent, - GlobalAccountDataEvent, GlobalAccountDataEventType, - }, - push::Ruleset, - DeviceId, EventId, RoomId, UserId, -}; -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - fs::{self, remove_dir_all}, - io::Write, - mem::size_of, - ops::Deref, - path::Path, - sync::{Arc, Mutex, RwLock}, -}; -use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; -use tracing::{debug, error, info, warn}; - -pub struct Database { - _db: Arc<dyn DatabaseEngine>, - pub globals: globals::Globals, - pub users: users::Users, - pub uiaa: uiaa::Uiaa, - pub rooms: rooms::Rooms, - pub account_data: account_data::AccountData, - pub media: media::Media, - pub key_backups: key_backups::KeyBackups, - pub transaction_ids: transaction_ids::TransactionIds, - pub sending: sending::Sending, - pub admin: admin::Admin, - pub appservice: appservice::Appservice, - pub pusher: pusher::PushData, -} - -impl Database { - /// Tries to remove the old database but ignores all errors. - pub fn try_remove(server_name: &str) -> Result<()> { - let mut path = ProjectDirs::from("xyz", "koesters", "conduit") - .ok_or_else(|| Error::bad_config("The OS didn't return a valid home directory path."))? - .data_dir() - .to_path_buf(); - path.push(server_name); - let _ = remove_dir_all(path); - - Ok(()) - } - - fn check_db_setup(config: &Config) -> Result<()> { - let path = Path::new(&config.database_path); - - let sled_exists = path.join("db").exists(); - let sqlite_exists = path.join("conduit.db").exists(); - let rocksdb_exists = path.join("IDENTITY").exists(); - - let mut count = 0; - - if sled_exists { - count += 1; - } - - if sqlite_exists { - count += 1; - } - - if rocksdb_exists { - count += 1; - } - - if count > 1 { - warn!("Multiple databases at database_path detected"); - return Ok(()); - } - - if sled_exists && config.database_backend != "sled" { - return Err(Error::bad_config( - "Found sled at database_path, but is not specified in config.", - )); - } - - if sqlite_exists && config.database_backend != "sqlite" { - return Err(Error::bad_config( - "Found sqlite at database_path, but is not specified in config.", - )); - } - - if rocksdb_exists && config.database_backend != "rocksdb" { - return Err(Error::bad_config( - "Found rocksdb at database_path, but is not specified in config.", - )); - } - - Ok(()) - } - - /// Load an existing database or create a new one. - pub async fn load_or_create(config: &Config) -> Result<Arc<TokioRwLock<Self>>> { - Self::check_db_setup(config)?; - - if !Path::new(&config.database_path).exists() { - std::fs::create_dir_all(&config.database_path) - .map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?; - } - - let builder: Arc<dyn DatabaseEngine> = match &*config.database_backend { - "sqlite" => { - #[cfg(not(feature = "sqlite"))] - return Err(Error::BadConfig("Database backend not found.")); - #[cfg(feature = "sqlite")] - Arc::new(Arc::<abstraction::sqlite::Engine>::open(config)?) - } - "rocksdb" => { - #[cfg(not(feature = "rocksdb"))] - return Err(Error::BadConfig("Database backend not found.")); - #[cfg(feature = "rocksdb")] - Arc::new(Arc::<abstraction::rocksdb::Engine>::open(config)?) - } - "persy" => { - #[cfg(not(feature = "persy"))] - return Err(Error::BadConfig("Database backend not found.")); - #[cfg(feature = "persy")] - Arc::new(Arc::<abstraction::persy::Engine>::open(config)?) - } - _ => { - return Err(Error::BadConfig("Database backend not found.")); - } - }; - - if config.max_request_size < 1024 { - eprintln!("ERROR: Max request size is less than 1KB. Please increase it."); - } - - let (admin_sender, admin_receiver) = mpsc::unbounded_channel(); - let (sending_sender, sending_receiver) = mpsc::unbounded_channel(); - - let db = Arc::new(TokioRwLock::from(Self { - _db: builder.clone(), - users: users::Users { - userid_password: builder.open_tree("userid_password")?, - userid_displayname: builder.open_tree("userid_displayname")?, - userid_avatarurl: builder.open_tree("userid_avatarurl")?, - userid_blurhash: builder.open_tree("userid_blurhash")?, - userdeviceid_token: builder.open_tree("userdeviceid_token")?, - userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, - userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, - token_userdeviceid: builder.open_tree("token_userdeviceid")?, - onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, - userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, - keychangeid_userid: builder.open_tree("keychangeid_userid")?, - keyid_key: builder.open_tree("keyid_key")?, - userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, - userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, - userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, - userfilterid_filter: builder.open_tree("userfilterid_filter")?, - todeviceid_events: builder.open_tree("todeviceid_events")?, - }, - uiaa: uiaa::Uiaa { - userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, - userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - }, - rooms: rooms::Rooms { - edus: rooms::RoomEdus { - readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, - roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt - roomuserid_lastprivatereadupdate: builder - .open_tree("roomuserid_lastprivatereadupdate")?, - typingid_userid: builder.open_tree("typingid_userid")?, - roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, - presenceid_presence: builder.open_tree("presenceid_presence")?, - userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, - }, - pduid_pdu: builder.open_tree("pduid_pdu")?, - eventid_pduid: builder.open_tree("eventid_pduid")?, - roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, - - alias_roomid: builder.open_tree("alias_roomid")?, - aliasid_alias: builder.open_tree("aliasid_alias")?, - publicroomids: builder.open_tree("publicroomids")?, - - tokenids: builder.open_tree("tokenids")?, - - roomserverids: builder.open_tree("roomserverids")?, - serverroomids: builder.open_tree("serverroomids")?, - userroomid_joined: builder.open_tree("userroomid_joined")?, - roomuserid_joined: builder.open_tree("roomuserid_joined")?, - roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, - roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, - roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, - userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, - roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, - userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, - roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, - - disabledroomids: builder.open_tree("disabledroomids")?, - - lazyloadedids: builder.open_tree("lazyloadedids")?, - - userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, - userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, - - statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, - shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, - - shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, - - roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, - - shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, - eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, - shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, - shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, - roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, - roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, - statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, - - eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, - softfailedeventids: builder.open_tree("softfailedeventids")?, - - referencedevents: builder.open_tree("referencedevents")?, - pdu_cache: Mutex::new(LruCache::new( - config - .pdu_cache_capacity - .try_into() - .expect("pdu cache capacity fits into usize"), - )), - auth_chain_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - shorteventid_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - eventidshort_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - shortstatekey_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - statekeyshort_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - our_real_users_cache: RwLock::new(HashMap::new()), - appservice_in_room_cache: RwLock::new(HashMap::new()), - lazy_load_waiting: Mutex::new(HashMap::new()), - stateinfo_cache: Mutex::new(LruCache::new( - (100.0 * config.conduit_cache_capacity_modifier) as usize, - )), - lasttimelinecount_cache: Mutex::new(HashMap::new()), - }, - account_data: account_data::AccountData { - roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, - roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, - }, - media: media::Media { - mediaid_file: builder.open_tree("mediaid_file")?, - }, - key_backups: key_backups::KeyBackups { - backupid_algorithm: builder.open_tree("backupid_algorithm")?, - backupid_etag: builder.open_tree("backupid_etag")?, - backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, - }, - transaction_ids: transaction_ids::TransactionIds { - userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, - }, - sending: sending::Sending { - servername_educount: builder.open_tree("servername_educount")?, - servernameevent_data: builder.open_tree("servernameevent_data")?, - servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, - maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), - sender: sending_sender, - }, - admin: admin::Admin { - sender: admin_sender, - }, - appservice: appservice::Appservice { - cached_registrations: Arc::new(RwLock::new(HashMap::new())), - id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, - }, - pusher: pusher::PushData { - senderkey_pusher: builder.open_tree("senderkey_pusher")?, - }, - globals: globals::Globals::load( - builder.open_tree("global")?, - builder.open_tree("server_signingkeys")?, - config.clone(), - )?, - })); - - let guard = db.read().await; - - // Matrix resource ownership is based on the server name; changing it - // requires recreating the database from scratch. - if guard.users.count()? > 0 { - let conduit_user = - UserId::parse_with_server_name("conduit", guard.globals.server_name()) - .expect("@conduit:server_name is valid"); - - if !guard.users.exists(&conduit_user)? { - error!( - "The {} server user does not exist, and the database is not new.", - conduit_user - ); - return Err(Error::bad_database( - "Cannot reuse an existing database after changing the server name, please delete the old one first." - )); - } - } - - // If the database has any data, perform data migrations before starting - let latest_database_version = 11; - - if guard.users.count()? > 0 { - let db = &*guard; - // MIGRATIONS - if db.globals.database_version()? < 1 { - for (roomserverid, _) in db.rooms.roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xff); - let room_id = parts.next().expect("split always returns one element"); - let servername = match parts.next() { - Some(s) => s, - None => { - error!("Migration: Invalid roomserverid in db."); - continue; - } - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xff); - serverroomid.extend_from_slice(room_id); - - db.rooms.serverroomids.insert(&serverroomid, &[])?; - } - - db.globals.bump_database_version(1)?; - - warn!("Migration: 0 -> 1 finished"); - } - - if db.globals.database_version()? < 2 { - // We accidentally inserted hashed versions of "" into the db instead of just "" - for (userid, password) in db.users.userid_password.iter() { - let password = utils::string_from_bytes(&password); - - let empty_hashed_password = password.map_or(false, |password| { - argon2::verify_encoded(&password, b"").unwrap_or(false) - }); - - if empty_hashed_password { - db.users.userid_password.insert(&userid, b"")?; - } - } - - db.globals.bump_database_version(2)?; - - warn!("Migration: 1 -> 2 finished"); - } - - if db.globals.database_version()? < 3 { - // Move media to filesystem - for (key, content) in db.media.mediaid_file.iter() { - if content.is_empty() { - continue; - } - - let path = db.globals.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - db.media.mediaid_file.insert(&key, &[])?; - } - - db.globals.bump_database_version(3)?; - - warn!("Migration: 2 -> 3 finished"); - } - - if db.globals.database_version()? < 4 { - // Add federated users to db as deactivated - for our_user in db.users.iter() { - let our_user = our_user?; - if db.users.is_deactivated(&our_user)? { - continue; - } - for room in db.rooms.rooms_joined(&our_user) { - for user in db.rooms.room_members(&room?) { - let user = user?; - if user.server_name() != db.globals.server_name() { - println!("Migration: Creating user {}", user); - db.users.create(&user, None)?; - } - } - } - } - - db.globals.bump_database_version(4)?; - - warn!("Migration: 3 -> 4 finished"); - } - - if db.globals.database_version()? < 5 { - // Upgrade user data store - for (roomuserdataid, _) in db.account_data.roomuserdataid_accountdata.iter() { - let mut parts = roomuserdataid.split(|&b| b == 0xff); - let room_id = parts.next().unwrap(); - let user_id = parts.next().unwrap(); - let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap(); - - let mut key = room_id.to_vec(); - key.push(0xff); - key.extend_from_slice(user_id); - key.push(0xff); - key.extend_from_slice(event_type); - - db.account_data - .roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; - } - - db.globals.bump_database_version(5)?; - - warn!("Migration: 4 -> 5 finished"); - } - - if db.globals.database_version()? < 6 { - // Set room member count - for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { - let string = utils::string_from_bytes(&roomid).unwrap(); - let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - db.rooms.update_joined_count(room_id, &db)?; - } - - db.globals.bump_database_version(6)?; - - warn!("Migration: 5 -> 6 finished"); - } - - if db.globals.database_version()? < 7 { - // Upgrade state store - let mut last_roomstates: HashMap<Box<RoomId>, u64> = HashMap::new(); - let mut current_sstatehash: Option<u64> = None; - let mut current_room = None; - let mut current_state = HashSet::new(); - let mut counter = 0; - - let mut handle_state = - |current_sstatehash: u64, - current_room: &RoomId, - current_state: HashSet<_>, - last_roomstates: &mut HashMap<_, _>| { - counter += 1; - println!("counter: {}", counter); - let last_roomsstatehash = last_roomstates.get(current_room); - - let states_parents = last_roomsstatehash.map_or_else( - || Ok(Vec::new()), - |&last_roomsstatehash| { - db.rooms.load_shortstatehash_info(dbg!(last_roomsstatehash)) - }, - )?; - - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .copied() - .collect::<HashSet<_>>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(¤t_state) - .copied() - .collect::<HashSet<_>>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - db.rooms.save_state_from_diff( - dbg!(current_sstatehash), - statediffnew, - statediffremoved, - 2, // every state change is 2 event changes on average - states_parents, - )?; - - /* - let mut tmp = db.rooms.load_shortstatehash_info(¤t_sstatehash, &db)?; - let state = tmp.pop().unwrap(); - println!( - "{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(¤t_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) - .collect::<Vec<_>>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) - .collect::<Vec<_>>() - ); - */ - - Ok::<_, Error>(()) - }; - - for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { - let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]) - .expect("number of bytes is correct"); - let sstatekey = k[size_of::<u64>()..].to_vec(); - if Some(sstatehash) != current_sstatehash { - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - last_roomstates - .insert(current_room.clone().unwrap(), current_sstatehash); - } - current_state = HashSet::new(); - current_sstatehash = Some(sstatehash); - - let event_id = db - .rooms - .shorteventid_eventid - .get(&seventid) - .unwrap() - .unwrap(); - let string = utils::string_from_bytes(&event_id).unwrap(); - let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = db.rooms.get_pdu(event_id).unwrap().unwrap(); - - if Some(&pdu.room_id) != current_room.as_ref() { - current_room = Some(pdu.room_id.clone()); - } - } - - let mut val = sstatekey; - val.extend_from_slice(&seventid); - current_state.insert(val.try_into().expect("size is correct")); - } - - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - } - - db.globals.bump_database_version(7)?; - - warn!("Migration: 6 -> 7 finished"); - } - - if db.globals.database_version()? < 8 { - // Generate short room ids for all rooms - for (room_id, _) in db.rooms.roomid_shortstatehash.iter() { - let shortroomid = db.globals.next_count()?.to_be_bytes(); - db.rooms.roomid_shortroomid.insert(&room_id, &shortroomid)?; - info!("Migration: 8"); - } - // Update pduids db layout - let mut batch = db.rooms.pduid_pdu.iter().filter_map(|(key, v)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(2, |&b| b == 0xff); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = db - .rooms - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_key = short_room_id; - new_key.extend_from_slice(count); - - Some((new_key, v)) - }); - - db.rooms.pduid_pdu.insert_batch(&mut batch)?; - - let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| { - if !value.starts_with(b"!") { - return None; - } - let mut parts = value.splitn(2, |&b| b == 0xff); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = db - .rooms - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_value = short_room_id; - new_value.extend_from_slice(count); - - Some((k, new_value)) - }); - - db.rooms.eventid_pduid.insert_batch(&mut batch2)?; - - db.globals.bump_database_version(8)?; - - warn!("Migration: 7 -> 8 finished"); - } - - if db.globals.database_version()? < 9 { - // Update tokenids db layout - let mut iter = db - .rooms - .tokenids - .iter() - .filter_map(|(key, _)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(4, |&b| b == 0xff); - let room_id = parts.next().unwrap(); - let word = parts.next().unwrap(); - let _pdu_id_room = parts.next().unwrap(); - let pdu_id_count = parts.next().unwrap(); - - let short_room_id = db - .rooms - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - let mut new_key = short_room_id; - new_key.extend_from_slice(word); - new_key.push(0xff); - new_key.extend_from_slice(pdu_id_count); - println!("old {:?}", key); - println!("new {:?}", new_key); - Some((new_key, Vec::new())) - }) - .peekable(); - - while iter.peek().is_some() { - db.rooms - .tokenids - .insert_batch(&mut iter.by_ref().take(1000))?; - println!("smaller batch done"); - } - - info!("Deleting starts"); - - let batch2: Vec<_> = db - .rooms - .tokenids - .iter() - .filter_map(|(key, _)| { - if key.starts_with(b"!") { - println!("del {:?}", key); - Some(key) - } else { - None - } - }) - .collect(); - - for key in batch2 { - println!("del"); - db.rooms.tokenids.remove(&key)?; - } - - db.globals.bump_database_version(9)?; - - warn!("Migration: 8 -> 9 finished"); - } - - if db.globals.database_version()? < 10 { - // Add other direction for shortstatekeys - for (statekey, shortstatekey) in db.rooms.statekey_shortstatekey.iter() { - db.rooms - .shortstatekey_statekey - .insert(&shortstatekey, &statekey)?; - } - - // Force E2EE device list updates so we can send them over federation - for user_id in db.users.iter().filter_map(|r| r.ok()) { - db.users - .mark_device_key_update(&user_id, &db.rooms, &db.globals)?; - } - - db.globals.bump_database_version(10)?; - - warn!("Migration: 9 -> 10 finished"); - } - - if db.globals.database_version()? < 11 { - db._db - .open_tree("userdevicesessionid_uiaarequest")? - .clear()?; - db.globals.bump_database_version(11)?; - - warn!("Migration: 10 -> 11 finished"); - } - - assert_eq!(11, latest_database_version); - - info!( - "Loaded {} database with version {}", - config.database_backend, latest_database_version - ); - } else { - guard - .globals - .bump_database_version(latest_database_version)?; - - // Create the admin room and server user on first run - create_admin_room(&guard).await?; - - warn!( - "Created new {} database with version {}", - config.database_backend, latest_database_version - ); - } - - // This data is probably outdated - guard.rooms.edus.presenceid_presence.clear()?; - - guard.admin.start_handler(Arc::clone(&db), admin_receiver); - - // Set emergency access for the conduit user - match set_emergency_access(&guard) { - Ok(pwd_set) => { - if pwd_set { - warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"); - guard.admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!")); - } - } - Err(e) => { - error!( - "Could not set the configured emergency password for the conduit user: {}", - e - ) - } - }; - - guard - .sending - .start_handler(Arc::clone(&db), sending_receiver); - - drop(guard); - - Self::start_cleanup_task(Arc::clone(&db), config).await; - - Ok(db) - } - - #[cfg(feature = "conduit_bin")] - pub async fn on_shutdown(db: Arc<TokioRwLock<Self>>) { - info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); - db.read().await.globals.rotate.fire(); - } - - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) { - let userid_bytes = user_id.as_bytes().to_vec(); - let mut userid_prefix = userid_bytes.clone(); - userid_prefix.push(0xff); - - let mut userdeviceid_prefix = userid_prefix.clone(); - userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); - userdeviceid_prefix.push(0xff); - - let mut futures = FuturesUnordered::new(); - - // Return when *any* user changed his key - // TODO: only send for user they share a room with - futures.push( - self.users - .todeviceid_events - .watch_prefix(&userdeviceid_prefix), - ); - - futures.push(self.rooms.userroomid_joined.watch_prefix(&userid_prefix)); - futures.push( - self.rooms - .userroomid_invitestate - .watch_prefix(&userid_prefix), - ); - futures.push(self.rooms.userroomid_leftstate.watch_prefix(&userid_prefix)); - futures.push( - self.rooms - .userroomid_notificationcount - .watch_prefix(&userid_prefix), - ); - futures.push( - self.rooms - .userroomid_highlightcount - .watch_prefix(&userid_prefix), - ); - - // Events for rooms we are in - for room_id in self.rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { - let short_roomid = self - .rooms - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let roomid_bytes = room_id.as_bytes().to_vec(); - let mut roomid_prefix = roomid_bytes.clone(); - roomid_prefix.push(0xff); - - // PDUs - futures.push(self.rooms.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - futures.push( - self.rooms - .edus - .roomid_lasttypingupdate - .watch_prefix(&roomid_bytes), - ); - - futures.push( - self.rooms - .edus - .readreceiptid_readreceipt - .watch_prefix(&roomid_prefix), - ); - - // Key changes - futures.push(self.users.keychangeid_userid.watch_prefix(&roomid_prefix)); - - // Room account data - let mut roomuser_prefix = roomid_prefix.clone(); - roomuser_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.account_data - .roomusertype_roomuserdataid - .watch_prefix(&roomuser_prefix), - ); - } - - let mut globaluserdata_prefix = vec![0xff]; - globaluserdata_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.account_data - .roomusertype_roomuserdataid - .watch_prefix(&globaluserdata_prefix), - ); - - // More key changes (used when user is not joined to any rooms) - futures.push(self.users.keychangeid_userid.watch_prefix(&userid_prefix)); - - // One time keys - futures.push( - self.users - .userid_lastonetimekeyupdate - .watch_prefix(&userid_bytes), - ); - - futures.push(Box::pin(self.globals.rotate.watch())); - - // Wait until one of them finds something - futures.next().await; - } - - #[tracing::instrument(skip(self))] - pub fn flush(&self) -> Result<()> { - let start = std::time::Instant::now(); - - let res = self._db.flush(); - - debug!("flush: took {:?}", start.elapsed()); - - res - } - - #[tracing::instrument(skip(db, config))] - pub async fn start_cleanup_task(db: Arc<TokioRwLock<Self>>, config: &Config) { - use tokio::time::interval; - - #[cfg(unix)] - use tokio::signal::unix::{signal, SignalKind}; - use tracing::info; - - use std::time::{Duration, Instant}; - - let timer_interval = Duration::from_secs(config.cleanup_second_interval as u64); - - tokio::spawn(async move { - let mut i = interval(timer_interval); - #[cfg(unix)] - let mut s = signal(SignalKind::hangup()).unwrap(); - - loop { - #[cfg(unix)] - tokio::select! { - _ = i.tick() => { - info!("cleanup: Timer ticked"); - } - _ = s.recv() => { - info!("cleanup: Received SIGHUP"); - } - }; - #[cfg(not(unix))] - { - i.tick().await; - info!("cleanup: Timer ticked") - } - - let start = Instant::now(); - if let Err(e) = db.read().await._db.cleanup() { - error!("cleanup: Errored: {}", e); - } else { - info!("cleanup: Finished in {:?}", start.elapsed()); - } - } - }); - } -} - -/// Sets the emergency password and push rules for the @conduit account in case emergency password is set -fn set_emergency_access(db: &Database) -> Result<bool> { - let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) - .expect("@conduit:server_name is a valid UserId"); - - db.users - .set_password(&conduit_user, db.globals.emergency_password().as_deref())?; - - let (ruleset, res) = match db.globals.emergency_password() { - Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)), - None => (Ruleset::new(), Ok(false)), - }; - - db.account_data.update( - None, - &conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &GlobalAccountDataEvent { - content: PushRulesEventContent { global: ruleset }, - }, - &db.globals, - )?; - - res -} - -pub struct DatabaseGuard(OwnedRwLockReadGuard<Database>); - -impl Deref for DatabaseGuard { - type Target = OwnedRwLockReadGuard<Database>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[cfg(feature = "conduit_bin")] -#[axum::async_trait] -impl<B> axum::extract::FromRequest<B> for DatabaseGuard -where - B: Send, -{ - type Rejection = axum::extract::rejection::ExtensionRejection; - - async fn from_request( - req: &mut axum::extract::RequestParts<B>, - ) -> Result<Self, Self::Rejection> { - use axum::extract::Extension; - - let Extension(db): Extension<Arc<TokioRwLock<Database>>> = - Extension::from_request(req).await?; - - Ok(DatabaseGuard(db.read_owned().await)) - } -} - -impl From<OwnedRwLockReadGuard<Database>> for DatabaseGuard { - fn from(val: OwnedRwLockReadGuard<Database>) -> Self { - Self(val) - } -} diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 74f3a45..93660f9 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -26,11 +26,11 @@ pub mod persy; ))] pub mod watchers; -pub trait DatabaseEngine: Send + Sync { +pub trait KeyValueDatabaseEngine: Send + Sync { fn open(config: &Config) -> Result<Self> where Self: Sized; - fn open_tree(&self, name: &'static str) -> Result<Arc<dyn Tree>>; + fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>; fn flush(&self) -> Result<()>; fn cleanup(&self) -> Result<()> { Ok(()) @@ -40,7 +40,7 @@ pub trait DatabaseEngine: Send + Sync { } } -pub trait Tree: Send + Sync { +pub trait KvTree: Send + Sync { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; diff --git a/src/database/abstraction/persy.rs b/src/database/abstraction/persy.rs index e78e731..1fa7a0d 100644 --- a/src/database/abstraction/persy.rs +++ b/src/database/abstraction/persy.rs @@ -1,6 +1,6 @@ use crate::{ database::{ - abstraction::{watchers::Watchers, DatabaseEngine, Tree}, + abstraction::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}, Config, }, Result, @@ -15,7 +15,7 @@ pub struct Engine { persy: Persy, } -impl DatabaseEngine for Arc<Engine> { +impl KeyValueDatabaseEngine for Arc<Engine> { fn open(config: &Config) -> Result<Self> { let mut cfg = persy::Config::new(); cfg.change_cache_size((config.db_cache_capacity_mb * 1024.0 * 1024.0) as u64); @@ -27,7 +27,7 @@ impl DatabaseEngine for Arc<Engine> { Ok(Arc::new(Engine { persy })) } - fn open_tree(&self, name: &'static str) -> Result<Arc<dyn Tree>> { + fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> { // Create if it doesn't exist if !self.persy.exists_index(name)? { let mut tx = self.persy.begin()?; @@ -61,7 +61,7 @@ impl PersyTree { } } -impl Tree for PersyTree { +impl KvTree for PersyTree { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { let result = self .persy diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 2cf9d5e..34d91d2 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -1,4 +1,4 @@ -use super::{super::Config, watchers::Watchers, DatabaseEngine, Tree}; +use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{utils, Result}; use std::{ future::Future, @@ -51,7 +51,7 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O db_opts } -impl DatabaseEngine for Arc<Engine> { +impl KeyValueDatabaseEngine for Arc<Engine> { fn open(config: &Config) -> Result<Self> { let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes).unwrap(); @@ -83,7 +83,7 @@ impl DatabaseEngine for Arc<Engine> { })) } - fn open_tree(&self, name: &'static str) -> Result<Arc<dyn Tree>> { + fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> { if !self.old_cfs.contains(&name.to_owned()) { // Create if it didn't exist let _ = self @@ -129,7 +129,7 @@ impl RocksDbEngineTree<'_> { } } -impl Tree for RocksDbEngineTree<'_> { +impl KvTree for RocksDbEngineTree<'_> { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { Ok(self.db.rocks.get_cf(&self.cf(), key)?) } @@ -161,6 +161,7 @@ impl Tree for RocksDbEngineTree<'_> { self.db .rocks .iterator_cf(&self.cf(), rocksdb::IteratorMode::Start) + //.map(|r| r.unwrap()) .map(|(k, v)| (Vec::from(k), Vec::from(v))), ) } @@ -184,6 +185,7 @@ impl Tree for RocksDbEngineTree<'_> { }, ), ) + //.map(|r| r.unwrap()) .map(|(k, v)| (Vec::from(k), Vec::from(v))), ) } @@ -191,7 +193,7 @@ impl Tree for RocksDbEngineTree<'_> { fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { let lock = self.write_lock.write().unwrap(); - let old = self.db.rocks.get_cf(&self.cf(), &key)?; + let old = self.db.rocks.get_cf(&self.cf(), key)?; let new = utils::increment(old.as_deref()).unwrap(); self.db.rocks.put_cf(&self.cf(), key, &new)?; @@ -224,6 +226,7 @@ impl Tree for RocksDbEngineTree<'_> { &self.cf(), rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), ) + //.map(|r| r.unwrap()) .map(|(k, v)| (Vec::from(k), Vec::from(v))) .take_while(move |(k, _)| k.starts_with(&prefix)), ) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 7cfa81a..b69efb6 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,4 +1,4 @@ -use super::{watchers::Watchers, DatabaseEngine, Tree}; +use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{database::Config, Result}; use parking_lot::{Mutex, MutexGuard}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; @@ -48,13 +48,13 @@ pub struct Engine { impl Engine { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result<Connection> { - let conn = Connection::open(&path)?; + let conn = Connection::open(path)?; - conn.pragma_update(Some(Main), "page_size", &2048)?; - conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; - conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; - conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; - conn.pragma_update(Some(Main), "wal_autocheckpoint", &0)?; + conn.pragma_update(Some(Main), "page_size", 2048)?; + conn.pragma_update(Some(Main), "journal_mode", "WAL")?; + conn.pragma_update(Some(Main), "synchronous", "NORMAL")?; + conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?; + conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?; Ok(conn) } @@ -75,12 +75,12 @@ impl Engine { pub fn flush_wal(self: &Arc<Self>) -> Result<()> { self.write_lock() - .pragma_update(Some(Main), "wal_checkpoint", &"RESTART")?; + .pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; Ok(()) } } -impl DatabaseEngine for Arc<Engine> { +impl KeyValueDatabaseEngine for Arc<Engine> { fn open(config: &Config) -> Result<Self> { let path = Path::new(&config.database_path).join("conduit.db"); @@ -105,8 +105,8 @@ impl DatabaseEngine for Arc<Engine> { Ok(arc) } - fn open_tree(&self, name: &str) -> Result<Arc<dyn Tree>> { - self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; + fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> { + self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?; Ok(Arc::new(SqliteTable { engine: Arc::clone(self), @@ -135,7 +135,6 @@ type TupleOfBytes = (Vec<u8>, Vec<u8>); impl SqliteTable { fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result<Option<Vec<u8>>> { - //dbg!(&self.name); Ok(guard .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? .query_row([key], |row| row.get(0)) @@ -143,7 +142,6 @@ impl SqliteTable { } fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { - //dbg!(&self.name); guard.execute( format!( "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", @@ -176,10 +174,7 @@ impl SqliteTable { statement .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() - .map(move |r| { - //dbg!(&name); - r.unwrap() - }), + .map(move |r| r.unwrap()), ); Box::new(PreparedStatementIterator { @@ -189,7 +184,7 @@ impl SqliteTable { } } -impl Tree for SqliteTable { +impl KvTree for SqliteTable { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { self.get_with_guard(self.engine.read_lock(), key) } @@ -276,10 +271,7 @@ impl Tree for SqliteTable { statement .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() - .map(move |r| { - //dbg!(&name); - r.unwrap() - }), + .map(move |r| r.unwrap()), ); Box::new(PreparedStatementIterator { iterator, @@ -301,10 +293,7 @@ impl Tree for SqliteTable { statement .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) .unwrap() - .map(move |r| { - //dbg!(&name); - r.unwrap() - }), + .map(move |r| r.unwrap()), ); Box::new(PreparedStatementIterator { diff --git a/src/database/admin.rs b/src/database/admin.rs deleted file mode 100644 index edc7691..0000000 --- a/src/database/admin.rs +++ /dev/null @@ -1,1145 +0,0 @@ -use std::{ - collections::BTreeMap, - convert::{TryFrom, TryInto}, - sync::Arc, - time::Instant, -}; - -use crate::{ - client_server::AUTO_GEN_PASSWORD_LENGTH, - error::{Error, Result}, - pdu::PduBuilder, - server_server, utils, - utils::HtmlEscape, - Database, PduEvent, -}; -use clap::Parser; -use regex::Regex; -use ruma::{ - events::{ - room::{ - canonical_alias::RoomCanonicalAliasEventContent, - create::RoomCreateEventContent, - guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - message::RoomMessageEventContent, - name::RoomNameEventContent, - power_levels::RoomPowerLevelsEventContent, - topic::RoomTopicEventContent, - }, - RoomEventType, - }, - EventId, RoomAliasId, RoomId, RoomName, RoomVersionId, ServerName, UserId, -}; -use serde_json::value::to_raw_value; -use tokio::sync::{mpsc, MutexGuard, RwLock, RwLockReadGuard}; - -#[derive(Debug)] -pub enum AdminRoomEvent { - ProcessMessage(String), - SendMessage(RoomMessageEventContent), -} - -#[derive(Clone)] -pub struct Admin { - pub sender: mpsc::UnboundedSender<AdminRoomEvent>, -} - -impl Admin { - pub fn start_handler( - &self, - db: Arc<RwLock<Database>>, - mut receiver: mpsc::UnboundedReceiver<AdminRoomEvent>, - ) { - tokio::spawn(async move { - // TODO: Use futures when we have long admin commands - //let mut futures = FuturesUnordered::new(); - - let guard = db.read().await; - - let conduit_user = UserId::parse(format!("@conduit:{}", guard.globals.server_name())) - .expect("@conduit:server_name is valid"); - - let conduit_room = guard - .rooms - .id_from_alias( - format!("#admins:{}", guard.globals.server_name()) - .as_str() - .try_into() - .expect("#admins:server_name is a valid room alias"), - ) - .expect("Database data for admin room alias must be valid") - .expect("Admin room must exist"); - - drop(guard); - - let send_message = |message: RoomMessageEventContent, - guard: RwLockReadGuard<'_, Database>, - mutex_lock: &MutexGuard<'_, ()>| { - guard - .rooms - .build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMessage, - content: to_raw_value(&message) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - &guard, - mutex_lock, - ) - .unwrap(); - }; - - loop { - tokio::select! { - Some(event) = receiver.recv() => { - let guard = db.read().await; - - let message_content = match event { - AdminRoomEvent::SendMessage(content) => content, - AdminRoomEvent::ProcessMessage(room_message) => process_admin_message(&*guard, room_message).await - }; - - let mutex_state = Arc::clone( - guard.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(conduit_room.clone()) - .or_default(), - ); - - let state_lock = mutex_state.lock().await; - - send_message(message_content, guard, &state_lock); - - drop(state_lock); - } - } - } - }); - } - - pub fn process_message(&self, room_message: String) { - self.sender - .send(AdminRoomEvent::ProcessMessage(room_message)) - .unwrap(); - } - - pub fn send_message(&self, message_content: RoomMessageEventContent) { - self.sender - .send(AdminRoomEvent::SendMessage(message_content)) - .unwrap(); - } -} - -// Parse and process a message from the admin room -async fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEventContent { - let mut lines = room_message.lines(); - let command_line = lines.next().expect("each string has at least one line"); - let body: Vec<_> = lines.collect(); - - let admin_command = match parse_admin_command(&command_line) { - Ok(command) => command, - Err(error) => { - let server_name = db.globals.server_name(); - let message = error - .to_string() - .replace("server.name", server_name.as_str()); - let html_message = usage_to_html(&message, server_name); - - return RoomMessageEventContent::text_html(message, html_message); - } - }; - - match process_admin_command(db, admin_command, body).await { - Ok(reply_message) => reply_message, - Err(error) => { - let markdown_message = format!( - "Encountered an error while handling the command:\n\ - ```\n{}\n```", - error, - ); - let html_message = format!( - "Encountered an error while handling the command:\n\ - <pre>\n{}\n</pre>", - error, - ); - - RoomMessageEventContent::text_html(markdown_message, html_message) - } - } -} - -// Parse chat messages from the admin room into an AdminCommand object -fn parse_admin_command(command_line: &str) -> std::result::Result<AdminCommand, String> { - // Note: argv[0] is `@conduit:servername:`, which is treated as the main command - let mut argv: Vec<_> = command_line.split_whitespace().collect(); - - // Replace `help command` with `command --help` - // Clap has a help subcommand, but it omits the long help description. - if argv.len() > 1 && argv[1] == "help" { - argv.remove(1); - argv.push("--help"); - } - - // Backwards compatibility with `register_appservice`-style commands - let command_with_dashes; - if argv.len() > 1 && argv[1].contains("_") { - command_with_dashes = argv[1].replace("_", "-"); - argv[1] = &command_with_dashes; - } - - AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) -} - -#[derive(Parser)] -#[clap(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))] -enum AdminCommand { - #[clap(verbatim_doc_comment)] - /// Register an appservice using its registration YAML - /// - /// This command needs a YAML generated by an appservice (such as a bridge), - /// which must be provided in a Markdown code-block below the command. - /// - /// Registering a new bridge using the ID of an existing bridge will replace - /// the old one. - /// - /// [commandbody] - /// # ``` - /// # yaml content here - /// # ``` - RegisterAppservice, - - /// Unregister an appservice using its ID - /// - /// You can find the ID using the `list-appservices` command. - UnregisterAppservice { - /// The appservice to unregister - appservice_identifier: String, - }, - - /// List all the currently registered appservices - ListAppservices, - - /// List all rooms the server knows about - ListRooms, - - /// List users in the database - ListLocalUsers, - - /// List all rooms we are currently handling an incoming pdu from - IncomingFederation, - - /// Deactivate a user - /// - /// User will not be removed from all rooms by default. - /// Use --leave-rooms to force the user to leave all rooms - DeactivateUser { - #[clap(short, long)] - leave_rooms: bool, - user_id: Box<UserId>, - }, - - #[clap(verbatim_doc_comment)] - /// Deactivate a list of users - /// - /// Recommended to use in conjunction with list-local-users. - /// - /// Users will not be removed from joined rooms by default. - /// Can be overridden with --leave-rooms flag. - /// Removing a mass amount of users from a room may cause a significant amount of leave events. - /// The time to leave rooms may depend significantly on joined rooms and servers. - /// - /// [commandbody] - /// # ``` - /// # User list here - /// # ``` - DeactivateAll { - #[clap(short, long)] - /// Remove users from their joined rooms - leave_rooms: bool, - #[clap(short, long)] - /// Also deactivate admin accounts - force: bool, - }, - - /// Get the auth_chain of a PDU - GetAuthChain { - /// An event ID (the $ character followed by the base64 reference hash) - event_id: Box<EventId>, - }, - - #[clap(verbatim_doc_comment)] - /// Parse and print a PDU from a JSON - /// - /// The PDU event is only checked for validity and is not added to the - /// database. - /// - /// [commandbody] - /// # ``` - /// # PDU json content here - /// # ``` - ParsePdu, - - /// Retrieve and print a PDU by ID from the Conduit database - GetPdu { - /// An event ID (a $ followed by the base64 reference hash) - event_id: Box<EventId>, - }, - - /// Print database memory usage statistics - DatabaseMemoryUsage, - - /// Show configuration values - ShowConfig, - - /// Reset user password - ResetPassword { - /// Username of the user for whom the password should be reset - username: String, - }, - - /// Create a new user - CreateUser { - /// Username of the new user - username: String, - /// Password of the new user, if unspecified one is generated - password: Option<String>, - }, - - /// Disables incoming federation handling for a room. - DisableRoom { room_id: Box<RoomId> }, - /// Enables incoming federation handling for a room again. - EnableRoom { room_id: Box<RoomId> }, -} - -async fn process_admin_command( - db: &Database, - command: AdminCommand, - body: Vec<&str>, -) -> Result<RoomMessageEventContent> { - let reply_message_content = match command { - AdminCommand::RegisterAppservice => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { - let appservice_config = body[1..body.len() - 1].join("\n"); - let parsed_config = serde_yaml::from_str::<serde_yaml::Value>(&appservice_config); - match parsed_config { - Ok(yaml) => match db.appservice.register_appservice(yaml) { - Ok(id) => RoomMessageEventContent::text_plain(format!( - "Appservice registered with ID: {}.", - id - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to register appservice: {}", - e - )), - }, - Err(e) => RoomMessageEventContent::text_plain(format!( - "Could not parse appservice config: {}", - e - )), - } - } else { - RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", - ) - } - } - AdminCommand::UnregisterAppservice { - appservice_identifier, - } => match db.appservice.unregister_appservice(&appservice_identifier) { - Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to unregister appservice: {}", - e - )), - }, - AdminCommand::ListAppservices => { - if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::<Vec<_>>()) { - let count = appservices.len(); - let output = format!( - "Appservices ({}): {}", - count, - appservices - .into_iter() - .filter_map(|r| r.ok()) - .collect::<Vec<_>>() - .join(", ") - ); - RoomMessageEventContent::text_plain(output) - } else { - RoomMessageEventContent::text_plain("Failed to get appservices.") - } - } - AdminCommand::ListRooms => { - let room_ids = db.rooms.iter_ids(); - let output = format!( - "Rooms:\n{}", - room_ids - .filter_map(|r| r.ok()) - .map(|id| id.to_string() - + "\tMembers: " - + &db - .rooms - .room_joined_count(&id) - .ok() - .flatten() - .unwrap_or(0) - .to_string()) - .collect::<Vec<_>>() - .join("\n") - ); - RoomMessageEventContent::text_plain(output) - } - AdminCommand::ListLocalUsers => match db.users.list_local_users() { - Ok(users) => { - let mut msg: String = format!("Found {} local user account(s):\n", users.len()); - msg += &users.join("\n"); - RoomMessageEventContent::text_plain(&msg) - } - Err(e) => RoomMessageEventContent::text_plain(e.to_string()), - }, - AdminCommand::IncomingFederation => { - let map = db.globals.roomid_federationhandletime.read().unwrap(); - let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); - - for (r, (e, i)) in map.iter() { - let elapsed = i.elapsed(); - msg += &format!( - "{} {}: {}m{}s\n", - r, - e, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - } - RoomMessageEventContent::text_plain(&msg) - } - AdminCommand::GetAuthChain { event_id } => { - let event_id = Arc::<EventId>::from(event_id); - if let Some(event) = db.rooms.get_pdu_json(&event_id)? { - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { - Error::bad_database("Invalid room id field in event in database") - })?; - let start = Instant::now(); - let count = server_server::get_auth_chain(room_id, vec![event_id], db) - .await? - .count(); - let elapsed = start.elapsed(); - RoomMessageEventContent::text_plain(format!( - "Loaded auth chain with length {} in {:?}", - count, elapsed - )) - } else { - RoomMessageEventContent::text_plain("Event not found.") - } - } - AdminCommand::ParsePdu => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { - let string = body[1..body.len() - 1].join("\n"); - match serde_json::from_str(&string) { - Ok(value) => { - match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { - Ok(hash) => { - let event_id = EventId::parse(format!("${}", hash)); - - match serde_json::from_value::<PduEvent>( - serde_json::to_value(value).expect("value is json"), - ) { - Ok(pdu) => RoomMessageEventContent::text_plain(format!( - "EventId: {:?}\n{:#?}", - event_id, pdu - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "EventId: {:?}\nCould not parse event: {}", - event_id, e - )), - } - } - Err(e) => RoomMessageEventContent::text_plain(format!( - "Could not parse PDU JSON: {:?}", - e - )), - } - } - Err(e) => RoomMessageEventContent::text_plain(format!( - "Invalid json in command body: {}", - e - )), - } - } else { - RoomMessageEventContent::text_plain("Expected code block in command body.") - } - } - AdminCommand::GetPdu { event_id } => { - let mut outlier = false; - let mut pdu_json = db.rooms.get_non_outlier_pdu_json(&event_id)?; - if pdu_json.is_none() { - outlier = true; - pdu_json = db.rooms.get_pdu_json(&event_id)?; - } - match pdu_json { - Some(json) => { - let json_text = - serde_json::to_string_pretty(&json).expect("canonical json is valid json"); - RoomMessageEventContent::text_html( - format!( - "{}\n```json\n{}\n```", - if outlier { - "PDU is outlier" - } else { - "PDU was accepted" - }, - json_text - ), - format!( - "<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n", - if outlier { - "PDU is outlier" - } else { - "PDU was accepted" - }, - HtmlEscape(&json_text) - ), - ) - } - None => RoomMessageEventContent::text_plain("PDU not found."), - } - } - AdminCommand::DatabaseMemoryUsage => match db._db.memory_usage() { - Ok(response) => RoomMessageEventContent::text_plain(response), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to get database memory usage: {}", - e - )), - }, - AdminCommand::ShowConfig => { - // Construct and send the response - RoomMessageEventContent::text_plain(format!("{}", db.globals.config)) - } - AdminCommand::ResetPassword { username } => { - let user_id = match UserId::parse_with_server_name( - username.as_str().to_lowercase(), - db.globals.server_name(), - ) { - Ok(id) => id, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "The supplied username is not a valid username: {}", - e - ))) - } - }; - - // Check if the specified user is valid - if !db.users.exists(&user_id)? - || db.users.is_deactivated(&user_id)? - || user_id - == UserId::parse_with_server_name("conduit", db.globals.server_name()) - .expect("conduit user exists") - { - return Ok(RoomMessageEventContent::text_plain( - "The specified user does not exist or is deactivated!", - )); - } - - let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); - - match db.users.set_password(&user_id, Some(new_password.as_str())) { - Ok(()) => RoomMessageEventContent::text_plain(format!( - "Successfully reset the password for user {}: {}", - user_id, new_password - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Couldn't reset the password for user {}: {}", - user_id, e - )), - } - } - AdminCommand::CreateUser { username, password } => { - let password = password.unwrap_or(utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); - // Validate user id - let user_id = match UserId::parse_with_server_name( - username.as_str().to_lowercase(), - db.globals.server_name(), - ) { - Ok(id) => id, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "The supplied username is not a valid username: {}", - e - ))) - } - }; - if user_id.is_historical() { - return Ok(RoomMessageEventContent::text_plain(format!( - "userid {user_id} is not allowed due to historical" - ))); - } - if db.users.exists(&user_id)? { - return Ok(RoomMessageEventContent::text_plain(format!( - "userid {user_id} already exists" - ))); - } - // Create user - db.users.create(&user_id, Some(password.as_str()))?; - - // Default to pretty displayname - let mut displayname = user_id.localpart().to_owned(); - - // If enabled append lightning bolt to display name (default true) - if db.globals.enable_lightning_bolt() { - displayname.push_str(" ⚡️"); - } - - db.users - .set_displayname(&user_id, Some(displayname.clone()))?; - - // Initial account data - db.account_data.update( - None, - &user_id, - ruma::events::GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: ruma::push::Ruleset::server_default(&user_id), - }, - }, - &db.globals, - )?; - - // we dont add a device since we're not the user, just the creator - - db.flush()?; - - // Inhibit login does not work for guests - RoomMessageEventContent::text_plain(format!( - "Created user with user_id: {user_id} and password: {password}" - )) - } - AdminCommand::DisableRoom { room_id } => { - db.rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; - RoomMessageEventContent::text_plain("Room disabled.") - } - AdminCommand::EnableRoom { room_id } => { - db.rooms.disabledroomids.remove(room_id.as_bytes())?; - RoomMessageEventContent::text_plain("Room enabled.") - } - AdminCommand::DeactivateUser { - leave_rooms, - user_id, - } => { - let user_id = Arc::<UserId>::from(user_id); - if db.users.exists(&user_id)? { - RoomMessageEventContent::text_plain(format!( - "Making {} leave all rooms before deactivation...", - user_id - )); - - db.users.deactivate_account(&user_id)?; - - if leave_rooms { - db.rooms.leave_all_rooms(&user_id, &db).await?; - } - - RoomMessageEventContent::text_plain(format!( - "User {} has been deactivated", - user_id - )) - } else { - RoomMessageEventContent::text_plain(format!( - "User {} doesn't exist on this server", - user_id - )) - } - } - AdminCommand::DeactivateAll { leave_rooms, force } => { - if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" { - let usernames = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>(); - - let mut user_ids: Vec<&UserId> = Vec::new(); - - for &username in &usernames { - match <&UserId>::try_from(username) { - Ok(user_id) => user_ids.push(user_id), - Err(_) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "{} is not a valid username", - username - ))) - } - } - } - - let mut deactivation_count = 0; - let mut admins = Vec::new(); - - if !force { - user_ids.retain(|&user_id| { - match db.users.is_admin(user_id, &db.rooms, &db.globals) { - Ok(is_admin) => match is_admin { - true => { - admins.push(user_id.localpart()); - false - } - false => true, - }, - Err(_) => false, - } - }) - } - - for &user_id in &user_ids { - match db.users.deactivate_account(user_id) { - Ok(_) => deactivation_count += 1, - Err(_) => {} - } - } - - if leave_rooms { - for &user_id in &user_ids { - let _ = db.rooms.leave_all_rooms(user_id, &db).await; - } - } - - if admins.is_empty() { - RoomMessageEventContent::text_plain(format!( - "Deactivated {} accounts.", - deactivation_count - )) - } else { - RoomMessageEventContent::text_plain(format!("Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts", deactivation_count, admins.join(", "))) - } - } else { - RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", - ) - } - } - }; - - Ok(reply_message_content) -} - -// Utility to turn clap's `--help` text to HTML. -fn usage_to_html(text: &str, server_name: &ServerName) -> String { - // Replace `@conduit:servername:-subcmdname` with `@conduit:servername: subcmdname` - let text = text.replace( - &format!("@conduit:{}:-", server_name), - &format!("@conduit:{}: ", server_name), - ); - - // For the conduit admin room, subcommands become main commands - let text = text.replace("SUBCOMMAND", "COMMAND"); - let text = text.replace("subcommand", "command"); - - // Escape option names (e.g. `<element-id>`) since they look like HTML tags - let text = text.replace("<", "<").replace(">", ">"); - - // Italicize the first line (command name and version text) - let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); - let text = re.replace_all(&text, "<em>$1</em>\n"); - - // Unmerge wrapped lines - let text = text.replace("\n ", " "); - - // Wrap option names in backticks. The lines look like: - // -V, --version Prints version information - // And are converted to: - // <code>-V, --version</code>: Prints version information - // (?m) enables multi-line mode for ^ and $ - let re = Regex::new("(?m)^ (([a-zA-Z_&;-]+(, )?)+) +(.*)$") - .expect("Regex compilation should not fail"); - let text = re.replace_all(&text, "<code>$1</code>: $4"); - - // Look for a `[commandbody]` tag. If it exists, use all lines below it that - // start with a `#` in the USAGE section. - let mut text_lines: Vec<&str> = text.lines().collect(); - let mut command_body = String::new(); - - if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { - text_lines.remove(line_index); - - while text_lines - .get(line_index) - .map(|line| line.starts_with("#")) - .unwrap_or(false) - { - command_body += if text_lines[line_index].starts_with("# ") { - &text_lines[line_index][2..] - } else { - &text_lines[line_index][1..] - }; - command_body += "[nobr]\n"; - text_lines.remove(line_index); - } - } - - let text = text_lines.join("\n"); - - // Improve the usage section - let text = if command_body.is_empty() { - // Wrap the usage line in code tags - let re = Regex::new("(?m)^USAGE:\n (@conduit:.*)$") - .expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n<code>$1</code>").to_string() - } else { - // Wrap the usage line in a code block, and add a yaml block example - // This makes the usage of e.g. `register-appservice` more accurate - let re = - Regex::new("(?m)^USAGE:\n (.*?)\n\n").expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n<pre>$1[nobr]\n[commandbodyblock]</pre>") - .replace("[commandbodyblock]", &command_body) - }; - - // Add HTML line-breaks - let text = text - .replace("\n\n\n", "\n\n") - .replace("\n", "<br>\n") - .replace("[nobr]<br>", ""); - - text -} - -/// Create the admin room. -/// -/// Users in this room are considered admins by conduit, and the room can be -/// used to issue admin commands by talking to the server user inside it. -pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { - let room_id = RoomId::new(db.globals.server_name()); - - db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Create a user for the server - let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) - .expect("@conduit:server_name is valid"); - - db.users.create(&conduit_user, None)?; - - let mut content = RoomCreateEventContent::new(conduit_user.clone()); - content.federate = true; - content.predecessor = None; - content.room_version = RoomVersionId::V6; - - // 1. The room create event - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - // 2. Make conduit bot join - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(conduit_user.to_string()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - // 3. Power levels - let mut users = BTreeMap::new(); - users.insert(conduit_user.clone(), 100.into()); - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - // 4.1 Join Rules - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - // 4.2 History Visibility - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new( - HistoryVisibility::Shared, - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - // 4.3 Guest Access - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomGuestAccess, - content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - // 5. Events implied by name and topic - let room_name = RoomName::parse(format!("{} Admin Room", db.globals.server_name())) - .expect("Room name is valid"); - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(Some(room_name))) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomTopic, - content: to_raw_value(&RoomTopicEventContent { - topic: format!("Manage {}", db.globals.server_name()), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - // 6. Room alias - let alias: Box<RoomAliasId> = format!("#admins:{}", db.globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomCanonicalAlias, - content: to_raw_value(&RoomCanonicalAliasEventContent { - alias: Some(alias.clone()), - alt_aliases: Vec::new(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; - - Ok(()) -} - -/// Invite the user to the conduit admin room. -/// -/// In conduit, this is equivalent to granting admin privileges. -pub(crate) async fn make_user_admin( - db: &Database, - user_id: &UserId, - displayname: String, -) -> Result<()> { - let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", db.globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - let room_id = db - .rooms - .id_from_alias(&admin_room_alias)? - .expect("Admin room must exist"); - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Use the server user to grant the new admin's power level - let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) - .expect("@conduit:server_name is valid"); - - // Invite and join the real user - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: Some(displayname), - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - &user_id, - &room_id, - &db, - &state_lock, - )?; - - // Set power level - let mut users = BTreeMap::new(); - users.insert(conduit_user.to_owned(), 100.into()); - users.insert(user_id.to_owned(), 100.into()); - - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - // Send welcome message - db.rooms.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMessage, - content: to_raw_value(&RoomMessageEventContent::text_html( - format!("## Thank you for trying out Conduit!\n\nConduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Website: https://conduit.rs\n> Git and Documentation: https://gitlab.com/famedly/conduit\n> Report issues: https://gitlab.com/famedly/conduit/-/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nConduit room (Ask questions and get notified on updates):\n`/join #conduit:fachschaften.org`\n\nConduit lounge (Off-topic, only Conduit users are allowed to join)\n`/join #conduit-lounge:conduit.rs`", db.globals.server_name()).to_owned(), - format!("<h2>Thank you for trying out Conduit!</h2>\n<p>Conduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Website: https://conduit.rs<br>Git and Documentation: https://gitlab.com/famedly/conduit<br>Report issues: https://gitlab.com/famedly/conduit/-/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>Conduit room (Ask questions and get notified on updates):<br><code>/join #conduit:fachschaften.org</code></p>\n<p>Conduit lounge (Off-topic, only Conduit users are allowed to join)<br><code>/join #conduit-lounge:conduit.rs</code></p>\n", db.globals.server_name()).to_owned(), - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &room_id, - &db, - &state_lock, - )?; - - Ok(()) -} diff --git a/src/database/account_data.rs b/src/database/key_value/account_data.rs index d85918f..e1eef96 100644 --- a/src/database/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,30 +1,23 @@ -use crate::{utils, Error, Result}; +use std::collections::HashMap; + use ruma::{ api::client::error::ErrorKind, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; -use serde::{de::DeserializeOwned, Serialize}; -use std::{collections::HashMap, sync::Arc}; - -use super::abstraction::Tree; -pub struct AccountData { - pub(super) roomuserdataid_accountdata: Arc<dyn Tree>, // RoomUserDataId = Room + User + Count + Type - pub(super) roomusertype_roomuserdataid: Arc<dyn Tree>, // RoomUserType = Room + User + Type -} +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; -impl AccountData { +impl service::account_data::Data for KeyValueDatabase { /// Places one event in the account data of the user and removes the previous entry. - #[tracing::instrument(skip(self, room_id, user_id, event_type, data, globals))] - pub fn update<T: Serialize>( + #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] + fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &T, - globals: &super::globals::Globals, + data: &serde_json::Value, ) -> Result<()> { let mut prefix = room_id .map(|r| r.to_string()) @@ -36,15 +29,14 @@ impl AccountData { prefix.push(0xff); let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&globals.next_count()?.to_be_bytes()); + roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); roomuserdataid.push(0xff); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); let mut key = prefix; key.extend_from_slice(event_type.to_string().as_bytes()); - let json = serde_json::to_value(data).expect("all types here can be serialized"); // TODO: maybe add error handling - if json.get("type").is_none() || json.get("content").is_none() { + if data.get("type").is_none() || data.get("content").is_none() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Account data doesn't have all required fields.", @@ -53,7 +45,7 @@ impl AccountData { self.roomuserdataid_accountdata.insert( &roomuserdataid, - &serde_json::to_vec(&json).expect("to_vec always works on json values"), + &serde_json::to_vec(&data).expect("to_vec always works on json values"), )?; let prev = self.roomusertype_roomuserdataid.get(&key)?; @@ -71,12 +63,12 @@ impl AccountData { /// Searches the account data for a specific kind. #[tracing::instrument(skip(self, room_id, user_id, kind))] - pub fn get<T: DeserializeOwned>( + fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, - ) -> Result<Option<T>> { + ) -> Result<Option<Box<serde_json::value::RawValue>>> { let mut key = room_id .map(|r| r.to_string()) .unwrap_or_default() @@ -104,7 +96,7 @@ impl AccountData { /// Returns all changes to the account data that happened after `since`. #[tracing::instrument(skip(self, room_id, user_id, since))] - pub fn changes_since( + fn changes_since( &self, room_id: Option<&RoomId>, user_id: &UserId, diff --git a/src/database/appservice.rs b/src/database/key_value/appservice.rs index edd5009..9a821a6 100644 --- a/src/database/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,20 +1,8 @@ -use crate::{utils, Error, Result}; -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; +use crate::{database::KeyValueDatabase, service, utils, Error, Result}; -use super::abstraction::Tree; - -pub struct Appservice { - pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>, - pub(super) id_appserviceregistrations: Arc<dyn Tree>, -} - -impl Appservice { +impl service::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller - /// - pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> { + fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> { // TODO: Rumaify let id = yaml.get("id").unwrap().as_str().unwrap(); self.id_appserviceregistrations.insert( @@ -34,7 +22,7 @@ impl Appservice { /// # Arguments /// /// * `service_name` - the name you send to register the service previously - pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations .remove(service_name.as_bytes())?; self.cached_registrations @@ -44,7 +32,7 @@ impl Appservice { Ok(()) } - pub fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> { + fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> { self.cached_registrations .read() .unwrap() @@ -66,14 +54,17 @@ impl Appservice { ) } - pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { - Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { - utils::string_from_bytes(&id) - .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) - })) + fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> { + Ok(Box::new(self.id_appserviceregistrations.iter().map( + |(id, _)| { + utils::string_from_bytes(&id).map_err(|_| { + Error::bad_database("Invalid id bytes in id_appserviceregistrations.") + }) + }, + ))) } - pub fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>> { + fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>> { self.iter_ids()? .filter_map(|id| id.ok()) .map(move |id| { diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs new file mode 100644 index 0000000..7b7675c --- /dev/null +++ b/src/database/key_value/globals.rs @@ -0,0 +1,233 @@ +use std::collections::BTreeMap; + +use async_trait::async_trait; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use ruma::{ + api::federation::discovery::{ServerSigningKeys, VerifyKey}, + signatures::Ed25519KeyPair, + DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, +}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +pub const COUNTER: &[u8] = b"c"; + +#[async_trait] +impl service::globals::Data for KeyValueDatabase { + fn next_count(&self) -> Result<u64> { + utils::u64_from_bytes(&self.global.increment(COUNTER)?) + .map_err(|_| Error::bad_database("Count has invalid bytes.")) + } + + fn current_count(&self) -> Result<u64> { + self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count has invalid bytes.")) + }) + } + + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let userid_bytes = user_id.as_bytes().to_vec(); + let mut userid_prefix = userid_bytes.clone(); + userid_prefix.push(0xff); + + let mut userdeviceid_prefix = userid_prefix.clone(); + userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); + userdeviceid_prefix.push(0xff); + + let mut futures = FuturesUnordered::new(); + + // Return when *any* user changed his key + // TODO: only send for user they share a room with + futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); + + futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push( + self.userroomid_notificationcount + .watch_prefix(&userid_prefix), + ); + futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); + + // Events for rooms we are in + for room_id in services() + .rooms + .state_cache + .rooms_joined(user_id) + .filter_map(|r| r.ok()) + { + let short_roomid = services() + .rooms + .short + .get_shortroomid(&room_id) + .ok() + .flatten() + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let roomid_bytes = room_id.as_bytes().to_vec(); + let mut roomid_prefix = roomid_bytes.clone(); + roomid_prefix.push(0xff); + + // PDUs + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes)); + + futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); + + // Key changes + futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); + + // Room account data + let mut roomuser_prefix = roomid_prefix.clone(); + roomuser_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&roomuser_prefix), + ); + } + + let mut globaluserdata_prefix = vec![0xff]; + globaluserdata_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&globaluserdata_prefix), + ); + + // More key changes (used when user is not joined to any rooms) + futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); + + // One time keys + futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); + + futures.push(Box::pin(services().globals.rotate.watch())); + + // Wait until one of them finds something + futures.next().await; + + Ok(()) + } + + fn cleanup(&self) -> Result<()> { + self._db.cleanup() + } + + fn memory_usage(&self) -> Result<String> { + self._db.memory_usage() + } + + fn load_keypair(&self) -> Result<Ed25519KeyPair> { + let keypair_bytes = self.global.get(b"keypair")?.map_or_else( + || { + let keypair = utils::generate_keypair(); + self.global.insert(b"keypair", &keypair)?; + Ok::<_, Error>(keypair) + }, + |s| Ok(s.to_vec()), + )?; + + let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); + + utils::string_from_bytes( + // 1. version + parts + .next() + .expect("splitn always returns at least one element"), + ) + .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) + .and_then(|version| { + // 2. key + parts + .next() + .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) + .map(|key| (version, key)) + }) + .and_then(|(version, key)| { + Ed25519KeyPair::from_der(key, version) + .map_err(|_| Error::bad_database("Private or public keys are invalid.")) + }) + } + fn remove_keypair(&self) -> Result<()> { + self.global.remove(b"keypair") + } + + fn add_signing_key( + &self, + origin: &ServerName, + new_keys: ServerSigningKeys, + ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { + // Not atomic, but this is not critical + let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + + let mut keys = signingkeys + .and_then(|keys| serde_json::from_slice(&keys).ok()) + .unwrap_or_else(|| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); + + let ServerSigningKeys { + verify_keys, + old_verify_keys, + .. + } = new_keys; + + keys.verify_keys.extend(verify_keys.into_iter()); + keys.old_verify_keys.extend(old_verify_keys.into_iter()); + + self.server_signingkeys.insert( + origin.as_bytes(), + &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), + )?; + + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + + Ok(tree) + } + + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + fn signing_keys_for( + &self, + origin: &ServerName, + ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { + let signingkeys = self + .server_signingkeys + .get(origin.as_bytes())? + .and_then(|bytes| serde_json::from_slice(&bytes).ok()) + .map(|keys: ServerSigningKeys| { + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + tree + }) + .unwrap_or_else(BTreeMap::new); + + Ok(signingkeys) + } + + fn database_version(&self) -> Result<u64> { + self.global.get(b"version")?.map_or(Ok(0), |version| { + utils::u64_from_bytes(&version) + .map_err(|_| Error::bad_database("Database version id is invalid.")) + }) + } + + fn bump_database_version(&self, new_version: u64) -> Result<()> { + self.global.insert(b"version", &new_version.to_be_bytes())?; + Ok(()) + } +} diff --git a/src/database/key_backups.rs b/src/database/key_value/key_backups.rs index 10443f6..900b700 100644 --- a/src/database/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,30 +1,23 @@ -use crate::{utils, Error, Result}; +use std::collections::BTreeMap; + use ruma::{ api::client::{ backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind, }, serde::Raw, - RoomId, UserId, + OwnedRoomId, RoomId, UserId, }; -use std::{collections::BTreeMap, sync::Arc}; - -use super::abstraction::Tree; -pub struct KeyBackups { - pub(super) backupid_algorithm: Arc<dyn Tree>, // BackupId = UserId + Version(Count) - pub(super) backupid_etag: Arc<dyn Tree>, // BackupId = UserId + Version(Count) - pub(super) backupkeyid_backup: Arc<dyn Tree>, // BackupKeyId = UserId + Version + RoomId + SessionId -} +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; -impl KeyBackups { - pub fn create_backup( +impl service::key_backups::Data for KeyValueDatabase { + fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>, - globals: &super::globals::Globals, ) -> Result<String> { - let version = globals.next_count()?.to_string(); + let version = services().globals.next_count()?.to_string(); let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -35,11 +28,11 @@ impl KeyBackups { &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), )?; self.backupid_etag - .insert(&key, &globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; Ok(version) } - pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(version.as_bytes()); @@ -56,12 +49,11 @@ impl KeyBackups { Ok(()) } - pub fn update_backup( + fn update_backup( &self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>, - globals: &super::globals::Globals, ) -> Result<String> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -77,11 +69,11 @@ impl KeyBackups { self.backupid_algorithm .insert(&key, backup_metadata.json().get().as_bytes())?; self.backupid_etag - .insert(&key, &globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; Ok(version.to_owned()) } - pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { + fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); let mut last_possible_key = prefix.clone(); @@ -102,7 +94,7 @@ impl KeyBackups { .transpose() } - pub fn get_latest_backup( + fn get_latest_backup( &self, user_id: &UserId, ) -> Result<Option<(String, Raw<BackupAlgorithm>)>> { @@ -133,11 +125,7 @@ impl KeyBackups { .transpose() } - pub fn get_backup( - &self, - user_id: &UserId, - version: &str, - ) -> Result<Option<Raw<BackupAlgorithm>>> { + fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(version.as_bytes()); @@ -150,14 +138,13 @@ impl KeyBackups { }) } - pub fn add_key( + fn add_key( &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>, - globals: &super::globals::Globals, ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -171,7 +158,7 @@ impl KeyBackups { } self.backupid_etag - .insert(&key, &globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; key.push(0xff); key.extend_from_slice(room_id.as_bytes()); @@ -184,7 +171,7 @@ impl KeyBackups { Ok(()) } - pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { + fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); prefix.extend_from_slice(version.as_bytes()); @@ -192,7 +179,7 @@ impl KeyBackups { Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) } - pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { + fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(version.as_bytes()); @@ -207,17 +194,17 @@ impl KeyBackups { .to_string()) } - pub fn get_all( + fn get_all( &self, user_id: &UserId, version: &str, - ) -> Result<BTreeMap<Box<RoomId>, RoomKeyBackup>> { + ) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); prefix.extend_from_slice(version.as_bytes()); prefix.push(0xff); - let mut rooms = BTreeMap::<Box<RoomId>, RoomKeyBackup>::new(); + let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new(); for result in self .backupkeyid_backup @@ -263,7 +250,7 @@ impl KeyBackups { Ok(rooms) } - pub fn get_room( + fn get_room( &self, user_id: &UserId, version: &str, @@ -300,7 +287,7 @@ impl KeyBackups { .collect()) } - pub fn get_session( + fn get_session( &self, user_id: &UserId, version: &str, @@ -325,7 +312,7 @@ impl KeyBackups { .transpose() } - pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(version.as_bytes()); @@ -338,12 +325,7 @@ impl KeyBackups { Ok(()) } - pub fn delete_room_keys( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - ) -> Result<()> { + fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(version.as_bytes()); @@ -358,7 +340,7 @@ impl KeyBackups { Ok(()) } - pub fn delete_room_key( + fn delete_room_key( &self, user_id: &UserId, version: &str, diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs new file mode 100644 index 0000000..6abe5ba --- /dev/null +++ b/src/database/key_value/media.rs @@ -0,0 +1,82 @@ +use ruma::api::client::error::ErrorKind; + +use crate::{database::KeyValueDatabase, service, utils, Error, Result}; + +impl service::media::Data for KeyValueDatabase { + fn create_file_metadata( + &self, + mxc: String, + width: u32, + height: u32, + content_disposition: Option<&str>, + content_type: Option<&str>, + ) -> Result<Vec<u8>> { + let mut key = mxc.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(&width.to_be_bytes()); + key.extend_from_slice(&height.to_be_bytes()); + key.push(0xff); + key.extend_from_slice( + content_disposition + .as_ref() + .map(|f| f.as_bytes()) + .unwrap_or_default(), + ); + key.push(0xff); + key.extend_from_slice( + content_type + .as_ref() + .map(|c| c.as_bytes()) + .unwrap_or_default(), + ); + + self.mediaid_file.insert(&key, &[])?; + + Ok(key) + } + + fn search_file_metadata( + &self, + mxc: String, + width: u32, + height: u32, + ) -> Result<(Option<String>, Option<String>, Vec<u8>)> { + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(&width.to_be_bytes()); + prefix.extend_from_slice(&height.to_be_bytes()); + prefix.push(0xff); + + let (key, _) = self + .mediaid_file + .scan_prefix(prefix) + .next() + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; + + let mut parts = key.rsplit(|&b| b == 0xff); + + let content_type = parts + .next() + .map(|bytes| { + utils::string_from_bytes(bytes).map_err(|_| { + Error::bad_database("Content type in mediaid_file is invalid unicode.") + }) + }) + .transpose()?; + + let content_disposition_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; + + let content_disposition = if content_disposition_bytes.is_empty() { + None + } else { + Some( + utils::string_from_bytes(content_disposition_bytes).map_err(|_| { + Error::bad_database("Content Disposition in mediaid_file is invalid unicode.") + })?, + ) + }; + Ok((content_disposition, content_type, key)) + } +} diff --git a/src/database/key_value/mod.rs b/src/database/key_value/mod.rs new file mode 100644 index 0000000..c4496af --- /dev/null +++ b/src/database/key_value/mod.rs @@ -0,0 +1,13 @@ +mod account_data; +//mod admin; +mod appservice; +mod globals; +mod key_backups; +mod media; +//mod pdu; +mod pusher; +mod rooms; +mod sending; +mod transaction_ids; +mod uiaa; +mod users; diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs new file mode 100644 index 0000000..50a6fac --- /dev/null +++ b/src/database/key_value/pusher.rs @@ -0,0 +1,79 @@ +use ruma::{ + api::client::push::{set_pusher, Pusher}, + UserId, +}; + +use crate::{database::KeyValueDatabase, service, utils, Error, Result}; + +impl service::pusher::Data for KeyValueDatabase { + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { + match &pusher { + set_pusher::v3::PusherAction::Post(data) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); + self.senderkey_pusher.insert( + &key, + &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), + )?; + Ok(()) + } + set_pusher::v3::PusherAction::Delete(ids) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(ids.pushkey.as_bytes()); + self.senderkey_pusher + .remove(&key) + .map(|_| ()) + .map_err(Into::into) + } + } + } + + fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> { + let mut senderkey = sender.as_bytes().to_vec(); + senderkey.push(0xff); + senderkey.extend_from_slice(pushkey.as_bytes()); + + self.senderkey_pusher + .get(&senderkey)? + .map(|push| { + serde_json::from_slice(&push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .transpose() + } + + fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + self.senderkey_pusher + .scan_prefix(prefix) + .map(|(_, push)| { + serde_json::from_slice(&push) + .map_err(|_| Error::bad_database("Invalid Pusher in db.")) + }) + .collect() + } + + fn get_pushkeys<'a>( + &'a self, + sender: &UserId, + ) -> Box<dyn Iterator<Item = Result<String>> + 'a> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { + let mut parts = k.splitn(2, |&b| b == 0xff); + let _senderkey = parts.next(); + let push_key = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; + let push_key_string = utils::string_from_bytes(push_key) + .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; + + Ok(push_key_string) + })) + } +} diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs new file mode 100644 index 0000000..6f23032 --- /dev/null +++ b/src/database/key_value/rooms/alias.rs @@ -0,0 +1,60 @@ +use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +impl service::rooms::alias::Data for KeyValueDatabase { + fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { + self.alias_roomid + .insert(alias.alias().as_bytes(), room_id.as_bytes())?; + let mut aliasid = room_id.as_bytes().to_vec(); + aliasid.push(0xff); + aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; + Ok(()) + } + + fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { + if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { + let mut prefix = room_id.to_vec(); + prefix.push(0xff); + + for (key, _) in self.aliasid_alias.scan_prefix(prefix) { + self.aliasid_alias.remove(&key)?; + } + self.alias_roomid.remove(alias.alias().as_bytes())?; + } else { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Alias does not exist.", + )); + } + Ok(()) + } + + fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { + self.alias_roomid + .get(alias.alias().as_bytes())? + .map(|bytes| { + RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Room ID in alias_roomid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) + }) + .transpose() + } + + fn local_aliases_for_room<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? + .try_into() + .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) + })) + } +} diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs new file mode 100644 index 0000000..60057ac --- /dev/null +++ b/src/database/key_value/rooms/auth_chain.rs @@ -0,0 +1,61 @@ +use std::{collections::HashSet, mem::size_of, sync::Arc}; + +use crate::{database::KeyValueDatabase, service, utils, Result}; + +impl service::rooms::auth_chain::Data for KeyValueDatabase { + fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> { + // Check RAM cache + if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { + return Ok(Some(Arc::clone(result))); + } + + // We only save auth chains for single events in the db + if key.len() == 1 { + // Check DB cache + let chain = self + .shorteventid_authchain + .get(&key[0].to_be_bytes())? + .map(|chain| { + chain + .chunks_exact(size_of::<u64>()) + .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) + .collect() + }); + + if let Some(chain) = chain { + let chain = Arc::new(chain); + + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(vec![key[0]], Arc::clone(&chain)); + + return Ok(Some(chain)); + } + } + + Ok(None) + } + + fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> { + // Only persist single events in db + if key.len() == 1 { + self.shorteventid_authchain.insert( + &key[0].to_be_bytes(), + &auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::<Vec<u8>>(), + )?; + } + + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(key, auth_chain); + + Ok(()) + } +} diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs new file mode 100644 index 0000000..e05dee8 --- /dev/null +++ b/src/database/key_value/rooms/directory.rs @@ -0,0 +1,28 @@ +use ruma::{OwnedRoomId, RoomId}; + +use crate::{database::KeyValueDatabase, service, utils, Error, Result}; + +impl service::rooms::directory::Data for KeyValueDatabase { + fn set_public(&self, room_id: &RoomId) -> Result<()> { + self.publicroomids.insert(room_id.as_bytes(), &[]) + } + + fn set_not_public(&self, room_id: &RoomId) -> Result<()> { + self.publicroomids.remove(room_id.as_bytes()) + } + + fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { + Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) + } + + fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { + Box::new(self.publicroomids.iter().map(|(bytes, _)| { + RoomId::parse( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Room ID in publicroomids is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) + })) + } +} diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs new file mode 100644 index 0000000..6c65291 --- /dev/null +++ b/src/database/key_value/rooms/edus/mod.rs @@ -0,0 +1,7 @@ +mod presence; +mod read_receipt; +mod typing; + +use crate::{database::KeyValueDatabase, service}; + +impl service::rooms::edus::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs new file mode 100644 index 0000000..904b1c4 --- /dev/null +++ b/src/database/key_value/rooms/edus/presence.rs @@ -0,0 +1,152 @@ +use std::collections::HashMap; + +use ruma::{ + events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId, +}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +impl service::rooms::edus::presence::Data for KeyValueDatabase { + fn update_presence( + &self, + user_id: &UserId, + room_id: &RoomId, + presence: PresenceEvent, + ) -> Result<()> { + // TODO: Remove old entry? Or maybe just wipe completely from time to time? + + let count = services().globals.next_count()?.to_be_bytes(); + + let mut presence_id = room_id.as_bytes().to_vec(); + presence_id.push(0xff); + presence_id.extend_from_slice(&count); + presence_id.push(0xff); + presence_id.extend_from_slice(presence.sender.as_bytes()); + + self.presenceid_presence.insert( + &presence_id, + &serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"), + )?; + + self.userid_lastpresenceupdate.insert( + user_id.as_bytes(), + &utils::millis_since_unix_epoch().to_be_bytes(), + )?; + + Ok(()) + } + + fn ping_presence(&self, user_id: &UserId) -> Result<()> { + self.userid_lastpresenceupdate.insert( + user_id.as_bytes(), + &utils::millis_since_unix_epoch().to_be_bytes(), + )?; + + Ok(()) + } + + fn last_presence_update(&self, user_id: &UserId) -> Result<Option<u64>> { + self.userid_lastpresenceupdate + .get(user_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.") + }) + }) + .transpose() + } + + fn get_presence_event( + &self, + room_id: &RoomId, + user_id: &UserId, + count: u64, + ) -> Result<Option<PresenceEvent>> { + let mut presence_id = room_id.as_bytes().to_vec(); + presence_id.push(0xff); + presence_id.extend_from_slice(&count.to_be_bytes()); + presence_id.push(0xff); + presence_id.extend_from_slice(user_id.as_bytes()); + + self.presenceid_presence + .get(&presence_id)? + .map(|value| parse_presence_event(&value)) + .transpose() + } + + fn presence_since( + &self, + room_id: &RoomId, + since: u64, + ) -> Result<HashMap<OwnedUserId, PresenceEvent>> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let mut first_possible_edu = prefix.clone(); + first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since + let mut hashmap = HashMap::new(); + + for (key, value) in self + .presenceid_presence + .iter_from(&first_possible_edu, false) + .take_while(|(key, _)| key.starts_with(&prefix)) + { + let user_id = UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Invalid UserId bytes in presenceid_presence."))?, + ) + .map_err(|_| Error::bad_database("Invalid UserId in presenceid_presence."))?; + + let presence = parse_presence_event(&value)?; + + hashmap.insert(user_id, presence); + } + + Ok(hashmap) + } + + /* + fn presence_maintain(&self, db: Arc<TokioRwLock<Database>>) { + // TODO @M0dEx: move this to a timed tasks module + tokio::spawn(async move { + loop { + select! { + Some(user_id) = self.presence_timers.next() { + // TODO @M0dEx: would it be better to acquire the lock outside the loop? + let guard = db.read().await; + + // TODO @M0dEx: add self.presence_timers + // TODO @M0dEx: maintain presence + } + } + } + }); + } + */ +} + +fn parse_presence_event(bytes: &[u8]) -> Result<PresenceEvent> { + let mut presence: PresenceEvent = serde_json::from_slice(bytes) + .map_err(|_| Error::bad_database("Invalid presence event in db."))?; + + let current_timestamp: UInt = utils::millis_since_unix_epoch() + .try_into() + .expect("time is valid"); + + if presence.content.presence == PresenceState::Online { + // Don't set last_active_ago when the user is online + presence.content.last_active_ago = None; + } else { + // Convert from timestamp to duration + presence.content.last_active_ago = presence + .content + .last_active_ago + .map(|timestamp| current_timestamp - timestamp); + } + + Ok(presence) +} diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs new file mode 100644 index 0000000..fa97ea3 --- /dev/null +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -0,0 +1,150 @@ +use std::mem; + +use ruma::{ + events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, +}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { + fn readreceipt_update( + &self, + user_id: &UserId, + room_id: &RoomId, + event: ReceiptEvent, + ) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + // Remove old entry + if let Some((old, _)) = self + .readreceiptid_readreceipt + .iter_from(&last_possible_key, true) + .take_while(|(key, _)| key.starts_with(&prefix)) + .find(|(key, _)| { + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element") + == user_id.as_bytes() + }) + { + // This is the old room_latest + self.readreceiptid_readreceipt.remove(&old)?; + } + + let mut room_latest_id = prefix; + room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + room_latest_id.push(0xff); + room_latest_id.extend_from_slice(user_id.as_bytes()); + + self.readreceiptid_readreceipt.insert( + &room_latest_id, + &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), + )?; + + Ok(()) + } + + fn readreceipts_since<'a>( + &'a self, + room_id: &RoomId, + since: u64, + ) -> Box< + dyn Iterator< + Item = Result<( + OwnedUserId, + u64, + Raw<ruma::events::AnySyncEphemeralRoomEvent>, + )>, + > + 'a, + > { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + let prefix2 = prefix.clone(); + + let mut first_possible_edu = prefix.clone(); + first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since + + Box::new( + self.readreceiptid_readreceipt + .iter_from(&first_possible_edu, false) + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(k, v)| { + let count = utils::u64_from_bytes( + &k[prefix.len()..prefix.len() + mem::size_of::<u64>()], + ) + .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + let user_id = UserId::parse( + utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..]) + .map_err(|_| { + Error::bad_database("Invalid readreceiptid userid bytes in db.") + })?, + ) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; + + let mut json = + serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| { + Error::bad_database( + "Read receipt in roomlatestid_roomlatest is invalid json.", + ) + })?; + json.remove("room_id"); + + Ok(( + user_id, + count, + Raw::from_json( + serde_json::value::to_raw_value(&json) + .expect("json is valid raw value"), + ), + )) + }), + ) + } + + fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_privateread + .insert(&key, &count.to_be_bytes())?; + + self.roomuserid_lastprivatereadupdate + .insert(&key, &services().globals.next_count()?.to_be_bytes()) + } + + fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_privateread + .get(&key)? + .map_or(Ok(None), |v| { + Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { + Error::bad_database("Invalid private read marker bytes") + })?)) + }) + } + + fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + Ok(self + .roomuserid_lastprivatereadupdate + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") + }) + }) + .transpose()? + .unwrap_or(0)) + } +} diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs new file mode 100644 index 0000000..5709192 --- /dev/null +++ b/src/database/key_value/rooms/edus/typing.rs @@ -0,0 +1,127 @@ +use std::{collections::HashSet, mem}; + +use ruma::{OwnedUserId, RoomId, UserId}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +impl service::rooms::edus::typing::Data for KeyValueDatabase { + fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let count = services().globals.next_count()?.to_be_bytes(); + + let mut room_typing_id = prefix; + room_typing_id.extend_from_slice(&timeout.to_be_bytes()); + room_typing_id.push(0xff); + room_typing_id.extend_from_slice(&count); + + self.typingid_userid + .insert(&room_typing_id, user_id.as_bytes())?; + + self.roomid_lasttypingupdate + .insert(room_id.as_bytes(), &count)?; + + Ok(()) + } + + fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let user_id = user_id.to_string(); + + let mut found_outdated = false; + + // Maybe there are multiple ones from calling roomtyping_add multiple times + for outdated_edu in self + .typingid_userid + .scan_prefix(prefix) + .filter(|(_, v)| &**v == user_id.as_bytes()) + { + self.typingid_userid.remove(&outdated_edu.0)?; + found_outdated = true; + } + + if found_outdated { + self.roomid_lasttypingupdate.insert( + room_id.as_bytes(), + &services().globals.next_count()?.to_be_bytes(), + )?; + } + + Ok(()) + } + + fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let current_timestamp = utils::millis_since_unix_epoch(); + + let mut found_outdated = false; + + // Find all outdated edus before inserting a new one + for outdated_edu in self + .typingid_userid + .scan_prefix(prefix) + .map(|(key, _)| { + Ok::<_, Error>(( + key.clone(), + utils::u64_from_bytes( + &key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| { + Error::bad_database("RoomTyping has invalid timestamp or delimiters.") + })?[0..mem::size_of::<u64>()], + ) + .map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?, + )) + }) + .filter_map(|r| r.ok()) + .take_while(|&(_, timestamp)| timestamp < current_timestamp) + { + // This is an outdated edu (time > timestamp) + self.typingid_userid.remove(&outdated_edu.0)?; + found_outdated = true; + } + + if found_outdated { + self.roomid_lasttypingupdate.insert( + room_id.as_bytes(), + &services().globals.next_count()?.to_be_bytes(), + )?; + } + + Ok(()) + } + + fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> { + Ok(self + .roomid_lasttypingupdate + .get(room_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") + }) + }) + .transpose()? + .unwrap_or(0)) + } + + fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let mut user_ids = HashSet::new(); + + for (_, user_id) in self.typingid_userid.scan_prefix(prefix) { + let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| { + Error::bad_database("User ID in typingid_userid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?; + + user_ids.insert(user_id); + } + + Ok(user_ids) + } +} diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs new file mode 100644 index 0000000..a19d52c --- /dev/null +++ b/src/database/key_value/rooms/lazy_load.rs @@ -0,0 +1,65 @@ +use ruma::{DeviceId, RoomId, UserId}; + +use crate::{database::KeyValueDatabase, service, Result}; + +impl service::rooms::lazy_loading::Data for KeyValueDatabase { + fn lazy_load_was_sent_before( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ll_user: &UserId, + ) -> Result<bool> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xff); + key.extend_from_slice(ll_user.as_bytes()); + Ok(self.lazyloadedids.get(&key)?.is_some()) + } + + fn lazy_load_confirm_delivery( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + confirmed_user_ids: &mut dyn Iterator<Item = &UserId>, + ) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xff); + + for ll_id in confirmed_user_ids { + let mut key = prefix.clone(); + key.extend_from_slice(ll_id.as_bytes()); + self.lazyloadedids.insert(&key, &[])?; + } + + Ok(()) + } + + fn lazy_load_reset( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xff); + + for (key, _) in self.lazyloadedids.scan_prefix(prefix) { + self.lazyloadedids.remove(&key)?; + } + + Ok(()) + } +} diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs new file mode 100644 index 0000000..57540c4 --- /dev/null +++ b/src/database/key_value/rooms/metadata.rs @@ -0,0 +1,45 @@ +use ruma::{OwnedRoomId, RoomId}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +impl service::rooms::metadata::Data for KeyValueDatabase { + fn exists(&self, room_id: &RoomId) -> Result<bool> { + let prefix = match services().rooms.short.get_shortroomid(room_id)? { + Some(b) => b.to_be_bytes().to_vec(), + None => return Ok(false), + }; + + // Look for PDUs in that room. + Ok(self + .pduid_pdu + .iter_from(&prefix, false) + .next() + .filter(|(k, _)| k.starts_with(&prefix)) + .is_some()) + } + + fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { + Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { + RoomId::parse( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Room ID in publicroomids is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) + })) + } + + fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { + Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) + } + + fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + if disabled { + self.disabledroomids.insert(room_id.as_bytes(), &[])?; + } else { + self.disabledroomids.remove(room_id.as_bytes())?; + } + + Ok(()) + } +} diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs new file mode 100644 index 0000000..406943e --- /dev/null +++ b/src/database/key_value/rooms/mod.rs @@ -0,0 +1,20 @@ +mod alias; +mod auth_chain; +mod directory; +mod edus; +mod lazy_load; +mod metadata; +mod outlier; +mod pdu_metadata; +mod search; +mod short; +mod state; +mod state_accessor; +mod state_cache; +mod state_compressor; +mod timeline; +mod user; + +use crate::{database::KeyValueDatabase, service}; + +impl service::rooms::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs new file mode 100644 index 0000000..7985ba8 --- /dev/null +++ b/src/database/key_value/rooms/outlier.rs @@ -0,0 +1,28 @@ +use ruma::{CanonicalJsonObject, EventId}; + +use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result}; + +impl service::rooms::outlier::Data for KeyValueDatabase { + fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map_or(Ok(None), |pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + } + + fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map_or(Ok(None), |pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + } + + fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + self.eventid_outlierpdu.insert( + event_id.as_bytes(), + &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), + ) + } +} diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs new file mode 100644 index 0000000..76ec734 --- /dev/null +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -0,0 +1,33 @@ +use std::sync::Arc; + +use ruma::{EventId, RoomId}; + +use crate::{database::KeyValueDatabase, service, Result}; + +impl service::rooms::pdu_metadata::Data for KeyValueDatabase { + fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { + for prev in event_ids { + let mut key = room_id.as_bytes().to_vec(); + key.extend_from_slice(prev.as_bytes()); + self.referencedevents.insert(&key, &[])?; + } + + Ok(()) + } + + fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { + let mut key = room_id.as_bytes().to_vec(); + key.extend_from_slice(event_id.as_bytes()); + Ok(self.referencedevents.get(&key)?.is_some()) + } + + fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { + self.softfailedeventids.insert(event_id.as_bytes(), &[]) + } + + fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { + self.softfailedeventids + .get(event_id.as_bytes()) + .map(|o| o.is_some()) + } +} diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs new file mode 100644 index 0000000..19ae57b --- /dev/null +++ b/src/database/key_value/rooms/search.rs @@ -0,0 +1,75 @@ +use std::mem::size_of; + +use ruma::RoomId; + +use crate::{database::KeyValueDatabase, service, services, utils, Result}; + +impl service::rooms::search::Data for KeyValueDatabase { + fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + let mut batch = message_body + .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .filter(|word| word.len() <= 50) + .map(str::to_lowercase) + .map(|word| { + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(word.as_bytes()); + key.push(0xff); + key.extend_from_slice(pdu_id); + (key, Vec::new()) + }); + + self.tokenids.insert_batch(&mut batch) + } + + fn search_pdus<'a>( + &'a self, + room_id: &RoomId, + search_string: &str, + ) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>> { + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + let prefix_clone = prefix.clone(); + + let words: Vec<_> = search_string + .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .map(str::to_lowercase) + .collect(); + + let iterators = words.clone().into_iter().map(move |word| { + let mut prefix2 = prefix.clone(); + prefix2.extend_from_slice(word.as_bytes()); + prefix2.push(0xff); + + let mut last_possible_id = prefix2.clone(); + last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.tokenids + .iter_from(&last_possible_id, true) // Newest pdus first + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(|(key, _)| key[key.len() - size_of::<u64>()..].to_vec()) + }); + + let common_elements = match utils::common_elements(iterators, |a, b| { + // We compare b with a because we reversed the iterator earlier + b.cmp(a) + }) { + Some(it) => it, + None => return Ok(None), + }; + + let mapped = common_elements.map(move |id| { + let mut pduid = prefix_clone.clone(); + pduid.extend_from_slice(&id); + pduid + }); + + Ok(Some((Box::new(mapped), words))) + } +} diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs new file mode 100644 index 0000000..c022317 --- /dev/null +++ b/src/database/key_value/rooms/short.rs @@ -0,0 +1,218 @@ +use std::sync::Arc; + +use ruma::{events::StateEventType, EventId, RoomId}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +impl service::rooms::short::Data for KeyValueDatabase { + fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { + if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { + return Ok(*short); + } + + let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { + Some(shorteventid) => utils::u64_from_bytes(&shorteventid) + .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, + None => { + let shorteventid = services().globals.next_count()?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + } + }; + + self.eventidshort_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), short); + + Ok(short) + } + + fn get_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<u64>> { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(Some(*short)); + } + + let mut statekey = event_type.to_string().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(state_key.as_bytes()); + + let short = self + .statekey_shortstatekey + .get(&statekey)? + .map(|shortstatekey| { + utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + }) + .transpose()?; + + if let Some(s) = short { + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), s); + } + + Ok(short) + } + + fn get_or_create_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result<u64> { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(*short); + } + + let mut statekey = event_type.to_string().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(state_key.as_bytes()); + + let short = match self.statekey_shortstatekey.get(&statekey)? { + Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, + None => { + let shortstatekey = services().globals.next_count()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes())?; + self.shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &statekey)?; + shortstatekey + } + }; + + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), short); + + Ok(short) + } + + fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { + if let Some(id) = self + .shorteventid_cache + .lock() + .unwrap() + .get_mut(&shorteventid) + { + return Ok(Arc::clone(id)); + } + + let bytes = self + .shorteventid_eventid + .get(&shorteventid.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + + let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; + + self.shorteventid_cache + .lock() + .unwrap() + .insert(shorteventid, Arc::clone(&event_id)); + + Ok(event_id) + } + + fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + if let Some(id) = self + .shortstatekey_cache + .lock() + .unwrap() + .get_mut(&shortstatekey) + { + return Ok(id.clone()); + } + + let bytes = self + .shortstatekey_statekey + .get(&shortstatekey.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; + + let mut parts = bytes.splitn(2, |&b| b == 0xff); + let eventtype_bytes = parts.next().expect("split always returns one entry"); + let statekey_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; + + let event_type = + StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| { + Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?; + + let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| { + Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.") + })?; + + let result = (event_type, state_key); + + self.shortstatekey_cache + .lock() + .unwrap() + .insert(shortstatekey, result.clone()); + + Ok(result) + } + + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + Ok(match self.statehash_shortstatehash.get(state_hash)? { + Some(shortstatehash) => ( + utils::u64_from_bytes(&shortstatehash) + .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + true, + ), + None => { + let shortstatehash = services().globals.next_count()?; + self.statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes())?; + (shortstatehash, false) + } + }) + } + + fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.roomid_shortroomid + .get(room_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + }) + .transpose() + } + + fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { + Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, + None => { + let short = services().globals.next_count()?; + self.roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes())?; + short + } + }) + } +} diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs new file mode 100644 index 0000000..f17d37b --- /dev/null +++ b/src/database/key_value/rooms/state.rs @@ -0,0 +1,73 @@ +use ruma::{EventId, OwnedEventId, RoomId}; +use std::collections::HashSet; + +use std::sync::Arc; +use tokio::sync::MutexGuard; + +use crate::{database::KeyValueDatabase, service, utils, Error, Result}; + +impl service::rooms::state::Data for KeyValueDatabase { + fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.roomid_shortstatehash + .get(room_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") + })?)) + }) + } + + fn set_room_state( + &self, + room_id: &RoomId, + new_shortstatehash: u64, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + self.roomid_shortstatehash + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; + Ok(()) + } + + fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { + self.shorteventid_shortstatehash + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + Ok(()) + } + + fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + self.roomid_pduleaves + .scan_prefix(prefix) + .map(|(_, bytes)| { + EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + }) + .collect() + } + + fn set_forward_extremities<'a>( + &self, + room_id: &RoomId, + event_ids: Vec<OwnedEventId>, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { + self.roomid_pduleaves.remove(&key)?; + } + + for event_id in event_ids { + let mut key = prefix.to_owned(); + key.extend_from_slice(event_id.as_bytes()); + self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; + } + + Ok(()) + } +} diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs new file mode 100644 index 0000000..0f0c0dc --- /dev/null +++ b/src/database/key_value/rooms/state_accessor.rs @@ -0,0 +1,186 @@ +use std::{collections::HashMap, sync::Arc}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; +use async_trait::async_trait; +use ruma::{events::StateEventType, EventId, RoomId}; + +#[async_trait] +impl service::rooms::state_accessor::Data for KeyValueDatabase { + async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state.into_iter() { + let parsed = services() + .rooms + .state_compressor + .parse_compressed_state_event(&compressed)?; + result.insert(parsed.0, parsed.1); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + Ok(result) + } + + async fn state_full( + &self, + shortstatehash: u64, + ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state { + let (_, eventid) = services() + .rooms + .state_compressor + .parse_compressed_state_event(&compressed)?; + if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { + result.insert( + ( + pdu.kind.to_string().into(), + pdu.state_key + .as_ref() + .ok_or_else(|| Error::bad_database("State event has no state key."))? + .clone(), + ), + pdu, + ); + } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + Ok(result) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn state_get_id( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<EventId>>> { + let shortstatekey = match services() + .rooms + .short + .get_shortstatekey(event_type, state_key)? + { + Some(s) => s, + None => return Ok(None), + }; + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| { + services() + .rooms + .state_compressor + .parse_compressed_state_event(&compressed) + .ok() + .map(|(_, id)| id) + })) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn state_get( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<PduEvent>>> { + self.state_get_id(shortstatehash, event_type, state_key)? + .map_or(Ok(None), |event_id| { + services().rooms.timeline.get_pdu(&event_id) + }) + } + + /// Returns the state hash for this pdu. + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { + self.eventid_shorteventid + .get(event_id.as_bytes())? + .map_or(Ok(None), |shorteventid| { + self.shorteventid_shortstatehash + .get(&shorteventid)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database( + "Invalid shortstatehash bytes in shorteventid_shortstatehash", + ) + }) + }) + .transpose() + }) + } + + /// Returns the full room state. + async fn room_state_full( + &self, + room_id: &RoomId, + ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { + if let Some(current_shortstatehash) = + services().rooms.state.get_room_shortstatehash(room_id)? + { + self.state_full(current_shortstatehash).await + } else { + Ok(HashMap::new()) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<EventId>>> { + if let Some(current_shortstatehash) = + services().rooms.state.get_room_shortstatehash(room_id)? + { + self.state_get_id(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn room_state_get( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<PduEvent>>> { + if let Some(current_shortstatehash) = + services().rooms.state.get_room_shortstatehash(room_id)? + { + self.state_get(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } +} diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs new file mode 100644 index 0000000..d0ea0c2 --- /dev/null +++ b/src/database/key_value/rooms/state_cache.rs @@ -0,0 +1,622 @@ +use std::{collections::HashSet, sync::Arc}; + +use regex::Regex; +use ruma::{ + events::{AnyStrippedStateEvent, AnySyncStateEvent}, + serde::Raw, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, +}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +impl service::rooms::state_cache::Data for KeyValueDatabase { + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + self.roomuseroncejoinedids.insert(&userroom_id, &[]) + } + + fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_joined.insert(&userroom_id, &[])?; + self.roomuserid_joined.insert(&roomuser_id, &[])?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + fn mark_as_invited( + &self, + user_id: &UserId, + room_id: &RoomId, + last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, + ) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_invitestate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()) + .expect("state to bytes always works"), + )?; + self.roomuserid_invitecount.insert( + &roomuser_id, + &services().globals.next_count()?.to_be_bytes(), + )?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_leftstate.insert( + &userroom_id, + &serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), + )?; // TODO + self.roomuserid_leftcount.insert( + &roomuser_id, + &services().globals.next_count()?.to_be_bytes(), + )?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + + Ok(()) + } + + fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { + let mut joinedcount = 0_u64; + let mut invitedcount = 0_u64; + let mut joined_servers = HashSet::new(); + let mut real_users = HashSet::new(); + + for joined in self.room_members(room_id).filter_map(|r| r.ok()) { + joined_servers.insert(joined.server_name().to_owned()); + if joined.server_name() == services().globals.server_name() + && !services().users.is_deactivated(&joined).unwrap_or(true) + { + real_users.insert(joined); + } + joinedcount += 1; + } + + for _invited in self.room_members_invited(room_id).filter_map(|r| r.ok()) { + invitedcount += 1; + } + + self.roomid_joinedcount + .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; + + self.roomid_invitedcount + .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; + + self.our_real_users_cache + .write() + .unwrap() + .insert(room_id.to_owned(), Arc::new(real_users)); + + for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { + if !joined_servers.remove(&old_joined_server) { + // Server not in room anymore + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xff); + roomserver_id.extend_from_slice(old_joined_server.as_bytes()); + + let mut serverroom_id = old_joined_server.as_bytes().to_vec(); + serverroom_id.push(0xff); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.remove(&roomserver_id)?; + self.serverroomids.remove(&serverroom_id)?; + } + } + + // Now only new servers are in joined_servers anymore + for server in joined_servers { + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xff); + roomserver_id.extend_from_slice(server.as_bytes()); + + let mut serverroom_id = server.as_bytes().to_vec(); + serverroom_id.push(0xff); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } + + self.appservice_in_room_cache + .write() + .unwrap() + .remove(room_id); + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id))] + fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> { + let maybe = self + .our_real_users_cache + .read() + .unwrap() + .get(room_id) + .cloned(); + if let Some(users) = maybe { + Ok(users) + } else { + self.update_joined_count(room_id)?; + Ok(Arc::clone( + self.our_real_users_cache + .read() + .unwrap() + .get(room_id) + .unwrap(), + )) + } + } + + #[tracing::instrument(skip(self, room_id, appservice))] + fn appservice_in_room( + &self, + room_id: &RoomId, + appservice: &(String, serde_yaml::Value), + ) -> Result<bool> { + let maybe = self + .appservice_in_room_cache + .read() + .unwrap() + .get(room_id) + .and_then(|map| map.get(&appservice.0)) + .copied(); + + if let Some(b) = maybe { + Ok(b) + } else if let Some(namespaces) = appservice.1.get("namespaces") { + let users = namespaces + .get("users") + .and_then(|users| users.as_sequence()) + .map_or_else(Vec::new, |users| { + users + .iter() + .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) + .collect::<Vec<_>>() + }); + + let bridge_user_id = appservice + .1 + .get("sender_localpart") + .and_then(|string| string.as_str()) + .and_then(|string| { + UserId::parse_with_server_name(string, services().globals.server_name()).ok() + }); + + let in_room = bridge_user_id + .map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) + || self.room_members(room_id).any(|userid| { + userid.map_or(false, |userid| { + users.iter().any(|r| r.is_match(userid.as_str())) + }) + }); + + self.appservice_in_room_cache + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default() + .insert(appservice.0.clone(), in_room); + + Ok(in_room) + } else { + Ok(false) + } + } + + /// Makes a user forget a room. + #[tracing::instrument(skip(self))] + fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + /// Returns an iterator of all servers participating in this room. + #[tracing::instrument(skip(self))] + fn room_servers<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { + ServerName::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("Server name in roomserverids is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) + })) + } + + #[tracing::instrument(skip(self))] + fn server_in_room<'a>(&'a self, server: &ServerName, room_id: &RoomId) -> Result<bool> { + let mut key = server.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + + self.serverroomids.get(&key).map(|o| o.is_some()) + } + + /// Returns an iterator of all rooms a server participates in (as far as we know). + #[tracing::instrument(skip(self))] + fn server_rooms<'a>( + &'a self, + server: &ServerName, + ) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { + let mut prefix = server.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { + RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) + })) + } + + /// Returns an iterator over all joined members of a room. + #[tracing::instrument(skip(self))] + fn room_members<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("User ID in roomuserid_joined is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) + })) + } + + #[tracing::instrument(skip(self))] + fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.roomid_joinedcount + .get(room_id.as_bytes())? + .map(|b| { + utils::u64_from_bytes(&b) + .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.roomid_invitedcount + .get(room_id.as_bytes())? + .map(|b| { + utils::u64_from_bytes(&b) + .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) + }) + .transpose() + } + + /// Returns an iterator over all User IDs who ever joined a room. + #[tracing::instrument(skip(self))] + fn room_useroncejoined<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new( + self.roomuseroncejoinedids + .scan_prefix(prefix) + .map(|(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database( + "User ID in room_useroncejoined is invalid unicode.", + ) + })?, + ) + .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) + }), + ) + } + + /// Returns an iterator over all invited members of a room. + #[tracing::instrument(skip(self))] + fn room_members_invited<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new( + self.roomuserid_invitecount + .scan_prefix(prefix) + .map(|(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("User ID in roomuserid_invited is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) + }), + ) + } + + #[tracing::instrument(skip(self))] + fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_invitecount + .get(&key)? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid invitecount in db.") + })?)) + }) + } + + #[tracing::instrument(skip(self))] + fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_leftcount + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid leftcount in db.")) + }) + .transpose() + } + + /// Returns an iterator over all rooms this user joined. + #[tracing::instrument(skip(self))] + fn rooms_joined<'a>( + &'a self, + user_id: &UserId, + ) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { + Box::new( + self.userroomid_joined + .scan_prefix(user_id.as_bytes().to_vec()) + .map(|(key, _)| { + RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_joined is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) + }), + ) + } + + /// Returns an iterator over all rooms a user was invited to. + #[tracing::instrument(skip(self))] + fn rooms_invited<'a>( + &'a self, + user_id: &UserId, + ) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new( + self.userroomid_invitestate + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_invited is invalid unicode.") + })?, + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_invited is invalid.") + })?; + + let state = serde_json::from_slice(&state).map_err(|_| { + Error::bad_database("Invalid state in userroomid_invitestate.") + })?; + + Ok((room_id, state)) + }), + ) + } + + #[tracing::instrument(skip(self))] + fn invite_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + + self.userroomid_invitestate + .get(&key)? + .map(|state| { + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + + Ok(state) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + fn left_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(room_id.as_bytes()); + + self.userroomid_leftstate + .get(&key)? + .map(|state| { + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + + Ok(state) + }) + .transpose() + } + + /// Returns an iterator over all rooms a user left. + #[tracing::instrument(skip(self))] + fn rooms_left<'a>( + &'a self, + user_id: &UserId, + ) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + Box::new( + self.userroomid_leftstate + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_invited is invalid unicode.") + })?, + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_invited is invalid.") + })?; + + let state = serde_json::from_slice(&state).map_err(|_| { + Error::bad_database("Invalid state in userroomid_leftstate.") + })?; + + Ok((room_id, state)) + }), + ) + } + + #[tracing::instrument(skip(self))] + fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) + } +} diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs new file mode 100644 index 0000000..d0a9be4 --- /dev/null +++ b/src/database/key_value/rooms/state_compressor.rs @@ -0,0 +1,61 @@ +use std::{collections::HashSet, mem::size_of}; + +use crate::{ + database::KeyValueDatabase, + service::{self, rooms::state_compressor::data::StateDiff}, + utils, Error, Result, +}; + +impl service::rooms::state_compressor::Data for KeyValueDatabase { + fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> { + let value = self + .shortstatehash_statediff + .get(&shortstatehash.to_be_bytes())? + .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + let parent = + utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length"); + let parent = if parent != 0 { Some(parent) } else { None }; + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = size_of::<u64>(); + while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i += size_of::<u64>(); + continue; + } + if add_mode { + added.insert(v.try_into().expect("we checked the size above")); + } else { + removed.insert(v.try_into().expect("we checked the size above")); + } + i += 2 * size_of::<u64>(); + } + + Ok(StateDiff { + parent, + added, + removed, + }) + } + + fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { + let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); + for new in &diff.added { + value.extend_from_slice(&new[..]); + } + + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in &diff.removed { + value.extend_from_slice(&removed[..]); + } + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value) + } +} diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs new file mode 100644 index 0000000..336317d --- /dev/null +++ b/src/database/key_value/rooms/timeline.rs @@ -0,0 +1,370 @@ +use std::{collections::hash_map, mem::size_of, sync::Arc}; + +use ruma::{ + api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId, +}; +use tracing::error; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; + +impl service::rooms::timeline::Data for KeyValueDatabase { + fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Look for PDUs in that room. + self.pduid_pdu + .iter_from(&prefix, false) + .filter(|(k, _)| k.starts_with(&prefix)) + .map(|(_, pdu)| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid first PDU in db.")) + .map(Arc::new) + }) + .next() + .transpose() + } + + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> { + match self + .lasttimelinecount_cache + .lock() + .unwrap() + .entry(room_id.to_owned()) + { + hash_map::Entry::Vacant(v) => { + if let Some(last_count) = self + .pdus_until(sender_user, room_id, u64::MAX)? + .filter_map(|r| { + // Filter out buggy events + if r.is_err() { + error!("Bad pdu in pdus_since: {:?}", r); + } + r.ok() + }) + .map(|(pduid, _)| self.pdu_count(&pduid)) + .next() + { + Ok(*v.insert(last_count?)) + } else { + Ok(0) + } + } + hash_map::Entry::Occupied(o) => Ok(*o.get()), + } + } + + /// Returns the `count` of this pdu's id. + fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pdu_id| self.pdu_count(&pdu_id)) + .transpose() + } + + /// Returns the json of a pdu. + fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the json of a pdu. + fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the pdu's id. + fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { + self.eventid_pduid.get(event_id.as_bytes()) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { + if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { + return Ok(Some(Arc::clone(p))); + } + + if let Some(pdu) = self + .eventid_pduid + .get(event_id.as_bytes())? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + .map(Arc::new) + }) + .transpose()? + { + self.pdu_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), Arc::clone(&pdu)); + Ok(Some(pdu)) + } else { + Ok(None) + } + } + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. + fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } + + /// Returns the `count` of this pdu's id. + fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> { + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..]) + .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) + } + + fn append_pdu( + &self, + pdu_id: &[u8], + pdu: &PduEvent, + json: &CanonicalJsonObject, + count: u64, + ) -> Result<()> { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), + )?; + + self.lasttimelinecount_cache + .lock() + .unwrap() + .insert(pdu.room_id.clone(), count); + + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; + + Ok(()) + } + + /// Removes a pdu and creates a new one with the same id. + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { + if self.pduid_pdu.get(pdu_id)?.is_some() { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(pdu).expect("CanonicalJsonObject is always a valid"), + )?; + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::NotFound, + "PDU does not exist.", + )) + } + } + + /// Returns an iterator over all events in a room that happened after the event with id `since` + /// in chronological order. + fn pdus_since<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + since: u64, + ) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> { + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Skip the first pdu if it's exactly at since, because we sent that last time + let mut first_pdu_id = prefix.clone(); + first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); + + let user_id = user_id.to_owned(); + + Ok(Box::new( + self.pduid_pdu + .iter_from(&first_pdu_id, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::<PduEvent>(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + }), + )) + } + + /// Returns an iterator over all events and their tokens in a room that happened before the + /// event with id `until` in reverse-chronological order. + fn pdus_until<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + until: u64, + ) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> { + // Create the first part of the full pdu id + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` + + let current: &[u8] = ¤t; + + let user_id = user_id.to_owned(); + + Ok(Box::new( + self.pduid_pdu + .iter_from(current, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::<PduEvent>(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + }), + )) + } + + fn pdus_after<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + from: u64, + ) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> { + // Create the first part of the full pdu id + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut current = prefix.clone(); + current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event + + let current: &[u8] = ¤t; + + let user_id = user_id.to_owned(); + + Ok(Box::new( + self.pduid_pdu + .iter_from(current, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::<PduEvent>(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((pdu_id, pdu)) + }), + )) + } + + fn increment_notification_counts( + &self, + room_id: &RoomId, + notifies: Vec<OwnedUserId>, + highlights: Vec<OwnedUserId>, + ) -> Result<()> { + let mut notifies_batch = Vec::new(); + let mut highlights_batch = Vec::new(); + for user in notifies { + let mut userroom_id = user.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + notifies_batch.push(userroom_id); + } + for user in highlights { + let mut userroom_id = user.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + highlights_batch.push(userroom_id); + } + + self.userroomid_notificationcount + .increment_batch(&mut notifies_batch.into_iter())?; + self.userroomid_highlightcount + .increment_batch(&mut highlights_batch.into_iter())?; + Ok(()) + } +} diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs new file mode 100644 index 0000000..4c43572 --- /dev/null +++ b/src/database/key_value/rooms/user.rs @@ -0,0 +1,149 @@ +use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; + +use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; + +impl service::rooms::user::Data for KeyValueDatabase { + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xff); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + self.userroomid_notificationcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_highlightcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + + self.roomuserid_lastnotificationread.insert( + &roomuser_id, + &services().globals.next_count()?.to_be_bytes(), + )?; + + Ok(()) + } + + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_notificationcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid notification count in db.")) + }) + .unwrap_or(Ok(0)) + } + + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_highlightcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid highlight count in db.")) + }) + .unwrap_or(Ok(0)) + } + + fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + Ok(self + .roomuserid_lastnotificationread + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") + }) + }) + .transpose()? + .unwrap_or(0)) + } + + fn associate_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + shortstatehash: u64, + ) -> Result<()> { + let shortroomid = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .insert(&key, &shortstatehash.to_be_bytes()) + } + + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { + let shortroomid = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists"); + + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); + + self.roomsynctoken_shortstatehash + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") + }) + }) + .transpose() + } + + fn get_shared_rooms<'a>( + &'a self, + users: Vec<OwnedUserId>, + ) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> { + let iterators = users.into_iter().map(move |user_id| { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + self.userroomid_joined + .scan_prefix(prefix) + .map(|(key, _)| { + let roomid_index = key + .iter() + .enumerate() + .find(|(_, &b)| b == 0xff) + .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? + .0 + + 1; // +1 because the room id starts AFTER the separator + + let room_id = key[roomid_index..].to_vec(); + + Ok::<_, Error>(room_id) + }) + .filter_map(|r| r.ok()) + }); + + // We use the default compare function because keys are sorted correctly (not reversed) + Ok(Box::new( + utils::common_elements(iterators, Ord::cmp) + .expect("users is not empty") + .map(|bytes| { + RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid RoomId bytes in userroomid_joined") + })?) + .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + }), + )) + } +} diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs new file mode 100644 index 0000000..3fc3e04 --- /dev/null +++ b/src/database/key_value/sending.rs @@ -0,0 +1,205 @@ +use ruma::{ServerName, UserId}; + +use crate::{ + database::KeyValueDatabase, + service::{ + self, + sending::{OutgoingKind, SendingEventType}, + }, + services, utils, Error, Result, +}; + +impl service::sending::Data for KeyValueDatabase { + fn active_requests<'a>( + &'a self, + ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> { + Box::new( + self.servercurrentevent_data + .iter() + .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), + ) + } + + fn active_requests_for<'a>( + &'a self, + outgoing_kind: &OutgoingKind, + ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> { + let prefix = outgoing_kind.get_prefix(); + Box::new( + self.servercurrentevent_data + .scan_prefix(prefix) + .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), + ) + } + + fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { + self.servercurrentevent_data.remove(&key) + } + + fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { + let prefix = outgoing_kind.get_prefix(); + for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { + self.servercurrentevent_data.remove(&key)?; + } + + Ok(()) + } + + fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { + let prefix = outgoing_kind.get_prefix(); + for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { + self.servercurrentevent_data.remove(&key).unwrap(); + } + + for (key, _) in self.servernameevent_data.scan_prefix(prefix) { + self.servernameevent_data.remove(&key).unwrap(); + } + + Ok(()) + } + + fn queue_requests( + &self, + requests: &[(&OutgoingKind, SendingEventType)], + ) -> Result<Vec<Vec<u8>>> { + let mut batch = Vec::new(); + let mut keys = Vec::new(); + for (outgoing_kind, event) in requests { + let mut key = outgoing_kind.get_prefix(); + if let SendingEventType::Pdu(value) = &event { + key.extend_from_slice(value) + } else { + key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()) + } + let value = if let SendingEventType::Edu(value) = &event { + &**value + } else { + &[] + }; + batch.push((key.clone(), value.to_owned())); + keys.push(key); + } + self.servernameevent_data + .insert_batch(&mut batch.into_iter())?; + Ok(keys) + } + + fn queued_requests<'a>( + &'a self, + outgoing_kind: &OutgoingKind, + ) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> { + let prefix = outgoing_kind.get_prefix(); + return Box::new( + self.servernameevent_data + .scan_prefix(prefix) + .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), + ); + } + + fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> { + for (e, key) in events { + let value = if let SendingEventType::Edu(value) = &e { + &**value + } else { + &[] + }; + self.servercurrentevent_data.insert(key, value)?; + self.servernameevent_data.remove(key)?; + } + + Ok(()) + } + + fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { + self.servername_educount + .insert(server_name.as_bytes(), &last_count.to_be_bytes()) + } + + fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> { + self.servername_educount + .get(server_name.as_bytes())? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) + }) + } +} + +#[tracing::instrument(skip(key))] +fn parse_servercurrentevent( + key: &[u8], + value: Vec<u8>, +) -> Result<(OutgoingKind, SendingEventType)> { + // Appservices start with a plus + Ok::<_, Error>(if key.starts_with(b"+") { + let mut parts = key[1..].splitn(2, |&b| b == 0xff); + + let server = parts.next().expect("splitn always returns one element"); + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + + let server = utils::string_from_bytes(server).map_err(|_| { + Error::bad_database("Invalid server bytes in server_currenttransaction") + })?; + + ( + OutgoingKind::Appservice(server), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + SendingEventType::Edu(value) + }, + ) + } else if key.starts_with(b"$") { + let mut parts = key[1..].splitn(3, |&b| b == 0xff); + + let user = parts.next().expect("splitn always returns one element"); + let user_string = utils::string_from_bytes(user) + .map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; + let user_id = UserId::parse(user_string) + .map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; + + let pushkey = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let pushkey_string = utils::string_from_bytes(pushkey) + .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; + + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + + ( + OutgoingKind::Push(user_id, pushkey_string), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + // I'm pretty sure this should never be called + SendingEventType::Edu(value) + }, + ) + } else { + let mut parts = key.splitn(2, |&b| b == 0xff); + + let server = parts.next().expect("splitn always returns one element"); + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + + let server = utils::string_from_bytes(server).map_err(|_| { + Error::bad_database("Invalid server bytes in server_currenttransaction") + })?; + + ( + OutgoingKind::Normal(ServerName::parse(server).map_err(|_| { + Error::bad_database("Invalid server string in server_currenttransaction") + })?), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + SendingEventType::Edu(value) + }, + ) + }) +} diff --git a/src/database/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index ed0970d..2ea6ad4 100644 --- a/src/database/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,16 +1,9 @@ -use std::sync::Arc; - -use crate::Result; use ruma::{DeviceId, TransactionId, UserId}; -use super::abstraction::Tree; - -pub struct TransactionIds { - pub(super) userdevicetxnid_response: Arc<dyn Tree>, // Response can be empty (/sendToDevice) or the event id (/send) -} +use crate::{database::KeyValueDatabase, service, Result}; -impl TransactionIds { - pub fn add_txnid( +impl service::transaction_ids::Data for KeyValueDatabase { + fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, @@ -28,7 +21,7 @@ impl TransactionIds { Ok(()) } - pub fn existing_txnid( + fn existing_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs new file mode 100644 index 0000000..5fd91b0 --- /dev/null +++ b/src/database/key_value/uiaa.rs @@ -0,0 +1,89 @@ +use ruma::{ + api::client::{error::ErrorKind, uiaa::UiaaInfo}, + CanonicalJsonValue, DeviceId, UserId, +}; + +use crate::{database::KeyValueDatabase, service, Error, Result}; + +impl service::uiaa::Data for KeyValueDatabase { + fn set_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + request: &CanonicalJsonValue, + ) -> Result<()> { + self.userdevicesessionid_uiaarequest + .write() + .unwrap() + .insert( + (user_id.to_owned(), device_id.to_owned(), session.to_owned()), + request.to_owned(), + ); + + Ok(()) + } + + fn get_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Option<CanonicalJsonValue> { + self.userdevicesessionid_uiaarequest + .read() + .unwrap() + .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) + .map(|j| j.to_owned()) + } + + fn update_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + uiaainfo: Option<&UiaaInfo>, + ) -> Result<()> { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + if let Some(uiaainfo) = uiaainfo { + self.userdevicesessionid_uiaainfo.insert( + &userdevicesessionid, + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + )?; + } else { + self.userdevicesessionid_uiaainfo + .remove(&userdevicesessionid)?; + } + + Ok(()) + } + + fn get_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Result<UiaaInfo> { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xff); + userdevicesessionid.extend_from_slice(session.as_bytes()); + + serde_json::from_slice( + &self + .userdevicesessionid_uiaainfo + .get(&userdevicesessionid)? + .ok_or(Error::BadRequest( + ErrorKind::Forbidden, + "UIAA session does not exist.", + ))?, + ) + .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + } +} diff --git a/src/database/users.rs b/src/database/key_value/users.rs index 7c15f1d..1cabab0 100644 --- a/src/database/users.rs +++ b/src/database/key_value/users.rs @@ -1,50 +1,29 @@ -use crate::{utils, Error, Result}; +use std::{collections::BTreeMap, mem::size_of}; + use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition}, + api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, events::{AnyToDeviceEvent, StateEventType}, serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId, - UInt, UserId, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, + OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId, }; -use std::{collections::BTreeMap, mem, sync::Arc}; use tracing::warn; -use super::abstraction::Tree; - -pub struct Users { - pub(super) userid_password: Arc<dyn Tree>, - pub(super) userid_displayname: Arc<dyn Tree>, - pub(super) userid_avatarurl: Arc<dyn Tree>, - pub(super) userid_blurhash: Arc<dyn Tree>, - pub(super) userdeviceid_token: Arc<dyn Tree>, - pub(super) userdeviceid_metadata: Arc<dyn Tree>, // This is also used to check if a device exists - pub(super) userid_devicelistversion: Arc<dyn Tree>, // DevicelistVersion = u64 - pub(super) token_userdeviceid: Arc<dyn Tree>, - - pub(super) onetimekeyid_onetimekeys: Arc<dyn Tree>, // OneTimeKeyId = UserId + DeviceKeyId - pub(super) userid_lastonetimekeyupdate: Arc<dyn Tree>, // LastOneTimeKeyUpdate = Count - pub(super) keychangeid_userid: Arc<dyn Tree>, // KeyChangeId = UserId/RoomId + Count - pub(super) keyid_key: Arc<dyn Tree>, // KeyId = UserId + KeyId (depends on key type) - pub(super) userid_masterkeyid: Arc<dyn Tree>, - pub(super) userid_selfsigningkeyid: Arc<dyn Tree>, - pub(super) userid_usersigningkeyid: Arc<dyn Tree>, - - pub(super) userfilterid_filter: Arc<dyn Tree>, // UserFilterId = UserId + FilterId - - pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count -} +use crate::{ + database::KeyValueDatabase, + service::{self, users::clean_signatures}, + services, utils, Error, Result, +}; -impl Users { +impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. - #[tracing::instrument(skip(self, user_id))] - pub fn exists(&self, user_id: &UserId) -> Result<bool> { + fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } /// Check if account is deactivated - #[tracing::instrument(skip(self, user_id))] - pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { + fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { Ok(self .userid_password .get(user_id.as_bytes())? @@ -55,37 +34,13 @@ impl Users { .is_empty()) } - /// Check if a user is an admin - #[tracing::instrument(skip(self, user_id, rooms, globals))] - pub fn is_admin( - &self, - user_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, - ) -> Result<bool> { - let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name())) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; - let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap(); - - rooms.is_joined(user_id, &admin_room_id) - } - - /// Create a new user account on this homeserver. - #[tracing::instrument(skip(self, user_id, password))] - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.set_password(user_id, password)?; - Ok(()) - } - /// Returns the number of users registered on this server. - #[tracing::instrument(skip(self))] - pub fn count(&self) -> Result<usize> { + fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) } /// Find out which user an access token belongs to. - #[tracing::instrument(skip(self, token))] - pub fn find_from_token(&self, token: &str) -> Result<Option<(Box<UserId>, String)>> { + fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> { self.token_userdeviceid .get(token.as_bytes())? .map_or(Ok(None), |bytes| { @@ -112,55 +67,29 @@ impl Users { } /// Returns an iterator over all users on this homeserver. - #[tracing::instrument(skip(self))] - pub fn iter(&self) -> impl Iterator<Item = Result<Box<UserId>>> + '_ { - self.userid_password.iter().map(|(bytes, _)| { + fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { + Box::new(self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("User ID in userid_password is invalid unicode.") })?) .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) - }) + })) } /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is greater then zero. - #[tracing::instrument(skip(self))] - pub fn list_local_users(&self) -> Result<Vec<String>> { + fn list_local_users(&self) -> Result<Vec<String>> { let users: Vec<String> = self .userid_password .iter() - .filter_map(|(username, pw)| self.get_username_with_valid_password(&username, &pw)) + .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) .collect(); Ok(users) } - /// Will only return with Some(username) if the password was not empty and the - /// username could be successfully parsed. - /// If utils::string_from_bytes(...) returns an error that username will be skipped - /// and the error will be logged. - #[tracing::instrument(skip(self))] - fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option<String> { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!( - "Failed to parse username while calling get_local_users(): {}", - e.to_string() - ); - None - } - } - } - } - /// Returns the password hash for the given user. - #[tracing::instrument(skip(self, user_id))] - pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { + fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { self.userid_password .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -171,10 +100,9 @@ impl Users { } /// Hash and set the user's password to the Argon2 hash - #[tracing::instrument(skip(self, user_id, password))] - pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { - if let Ok(hash) = utils::calculate_hash(password) { + if let Ok(hash) = utils::calculate_password_hash(password) { self.userid_password .insert(user_id.as_bytes(), hash.as_bytes())?; Ok(()) @@ -191,8 +119,7 @@ impl Users { } /// Returns the displayname of a user on this homeserver. - #[tracing::instrument(skip(self, user_id))] - pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { + fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { self.userid_displayname .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -203,8 +130,7 @@ impl Users { } /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - #[tracing::instrument(skip(self, user_id, displayname))] - pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { + fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { if let Some(displayname) = displayname { self.userid_displayname .insert(user_id.as_bytes(), displayname.as_bytes())?; @@ -216,8 +142,7 @@ impl Users { } /// Get the avatar_url of a user. - #[tracing::instrument(skip(self, user_id))] - pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<Box<MxcUri>>> { + fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> { self.userid_avatarurl .get(user_id.as_bytes())? .map(|bytes| { @@ -230,8 +155,7 @@ impl Users { } /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, avatar_url))] - pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<Box<MxcUri>>) -> Result<()> { + fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> { if let Some(avatar_url) = avatar_url { self.userid_avatarurl .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; @@ -243,8 +167,7 @@ impl Users { } /// Get the blurhash of a user. - #[tracing::instrument(skip(self, user_id))] - pub fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> { + fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> { self.userid_blurhash .get(user_id.as_bytes())? .map(|bytes| { @@ -257,8 +180,7 @@ impl Users { } /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, blurhash))] - pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> { + fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> { if let Some(blurhash) = blurhash { self.userid_blurhash .insert(user_id.as_bytes(), blurhash.as_bytes())?; @@ -270,8 +192,7 @@ impl Users { } /// Adds a new device to a user. - #[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))] - pub fn create_device( + fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -305,8 +226,7 @@ impl Users { } /// Removes a device from a user. - #[tracing::instrument(skip(self, user_id, device_id))] - pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -336,31 +256,32 @@ impl Users { } /// Returns an iterator over all device ids of this user. - #[tracing::instrument(skip(self, user_id))] - pub fn all_device_ids<'a>( + fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a { + ) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); // All devices have metadata - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xff) - .next() - .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? - .into()) - }) + Box::new( + self.userdeviceid_metadata + .scan_prefix(prefix) + .map(|(bytes, _)| { + Ok(utils::string_from_bytes( + bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { + Error::bad_database("UserDevice ID in db is invalid.") + })?, + ) + .map_err(|_| { + Error::bad_database("Device ID in userdeviceid_metadata is invalid.") + })? + .into()) + }), + ) } /// Replaces the access token of one device. - #[tracing::instrument(skip(self, user_id, device_id, token))] - pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -383,21 +304,12 @@ impl Users { Ok(()) } - #[tracing::instrument(skip( - self, - user_id, - device_id, - one_time_key_key, - one_time_key_value, - globals - ))] - pub fn add_one_time_key( + fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw<OneTimeKey>, - globals: &super::globals::Globals, ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -421,14 +333,15 @@ impl Users { &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), )?; - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + self.userid_lastonetimekeyupdate.insert( + user_id.as_bytes(), + &services().globals.next_count()?.to_be_bytes(), + )?; Ok(()) } - #[tracing::instrument(skip(self, user_id))] - pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { + fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { self.userid_lastonetimekeyupdate .get(user_id.as_bytes())? .map(|bytes| { @@ -439,14 +352,12 @@ impl Users { .unwrap_or(Ok(0)) } - #[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))] - pub fn take_one_time_key( + fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - globals: &super::globals::Globals, - ) -> Result<Option<(Box<DeviceKeyId>, Raw<OneTimeKey>)>> { + ) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); prefix.extend_from_slice(device_id.as_bytes()); @@ -455,8 +366,10 @@ impl Users { prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); prefix.push(b':'); - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + self.userid_lastonetimekeyupdate.insert( + user_id.as_bytes(), + &services().globals.next_count()?.to_be_bytes(), + )?; self.onetimekeyid_onetimekeys .scan_prefix(prefix) @@ -466,21 +379,19 @@ impl Users { Ok(( serde_json::from_slice( - &*key - .rsplit(|&b| b == 0xff) + key.rsplit(|&b| b == 0xff) .next() .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, ) .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, - serde_json::from_slice(&*value) + serde_json::from_slice(&value) .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, )) }) .transpose() } - #[tracing::instrument(skip(self, user_id, device_id))] - pub fn count_one_time_keys( + fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, @@ -496,8 +407,8 @@ impl Users { .scan_prefix(userdeviceid) .map(|(bytes, _)| { Ok::<_, Error>( - serde_json::from_slice::<Box<DeviceKeyId>>( - &*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { + serde_json::from_slice::<OwnedDeviceKeyId>( + bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { Error::bad_database("OneTimeKey ID in db is invalid.") })?, ) @@ -512,14 +423,11 @@ impl Users { Ok(counts) } - #[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))] - pub fn add_device_keys( + fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); @@ -530,27 +438,17 @@ impl Users { &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), )?; - self.mark_device_key_update(user_id, rooms, globals)?; + self.mark_device_key_update(user_id)?; Ok(()) } - #[tracing::instrument(skip( - self, - master_key, - self_signing_key, - user_signing_key, - rooms, - globals - ))] - pub fn add_cross_signing_keys( + fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>, user_signing_key: &Option<Raw<CrossSigningKey>>, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { // TODO: Check signatures @@ -653,20 +551,17 @@ impl Users { .insert(user_id.as_bytes(), &user_signing_key_key)?; } - self.mark_device_key_update(user_id, rooms, globals)?; + self.mark_device_key_update(user_id)?; Ok(()) } - #[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))] - pub fn sign_key( + fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { let mut key = target_id.as_bytes().to_vec(); key.push(0xff); @@ -684,7 +579,7 @@ impl Users { .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? .as_object_mut() .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? - .entry(sender_id.to_owned()) + .entry(sender_id.to_string()) .or_insert_with(|| serde_json::Map::new().into()); signatures @@ -698,18 +593,17 @@ impl Users { )?; // TODO: Should we notify about this change? - self.mark_device_key_update(target_id, rooms, globals)?; + self.mark_device_key_update(target_id)?; Ok(()) } - #[tracing::instrument(skip(self, user_or_room_id, from, to))] - pub fn keys_changed<'a>( + fn keys_changed<'a>( &'a self, user_or_room_id: &str, from: u64, to: Option<u64>, - ) -> impl Iterator<Item = Result<Box<UserId>>> + 'a { + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { let mut prefix = user_or_room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -718,41 +612,48 @@ impl Users { let to = to.unwrap_or(u64::MAX); - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to + Box::new( + self.keychangeid_userid + .iter_from(&start, false) + .take_while(move |(k, _)| { + k.starts_with(&prefix) + && if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { + if let Ok(c) = utils::u64_from_bytes(current) { + c <= to + } else { + warn!("BadDatabase: Could not parse keychangeid_userid bytes"); + false + } } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); + warn!("BadDatabase: Could not parse keychangeid_userid"); false } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) - }) - } - - #[tracing::instrument(skip(self, user_id, rooms, globals))] - pub fn mark_device_key_update( - &self, - user_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, - ) -> Result<()> { - let count = globals.next_count()?.to_be_bytes(); - for room_id in rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { + }) + .map(|(_, bytes)| { + UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database( + "User ID in devicekeychangeid_userid is invalid unicode.", + ) + })?) + .map_err(|_| { + Error::bad_database("User ID in devicekeychangeid_userid is invalid.") + }) + }), + ) + } + + fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + let count = services().globals.next_count()?.to_be_bytes(); + for room_id in services() + .rooms + .state_cache + .rooms_joined(user_id) + .filter_map(|r| r.ok()) + { // Don't send key updates to unencrypted rooms - if rooms + if services() + .rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? .is_none() { @@ -774,8 +675,7 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id))] - pub fn get_device_keys( + fn get_device_keys( &self, user_id: &UserId, device_id: &DeviceId, @@ -791,11 +691,10 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id, allowed_signatures))] - pub fn get_master_key<F: Fn(&UserId) -> bool>( + fn get_master_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result<Option<Raw<CrossSigningKey>>> { self.userid_masterkeyid .get(user_id.as_bytes())? @@ -813,11 +712,10 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id, allowed_signatures))] - pub fn get_self_signing_key<F: Fn(&UserId) -> bool>( + fn get_self_signing_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result<Option<Raw<CrossSigningKey>>> { self.userid_selfsigningkeyid .get(user_id.as_bytes())? @@ -835,8 +733,7 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] - pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { + fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { self.userid_usersigningkeyid .get(user_id.as_bytes())? .map_or(Ok(None), |key| { @@ -848,29 +745,19 @@ impl Users { }) } - #[tracing::instrument(skip( - self, - sender, - target_user_id, - target_device_id, - event_type, - content, - globals - ))] - pub fn add_to_device_event( + fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, - globals: &super::globals::Globals, ) -> Result<()> { let mut key = target_user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(target_device_id.as_bytes()); key.push(0xff); - key.extend_from_slice(&globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); let mut json = serde_json::Map::new(); json.insert("type".to_owned(), event_type.to_owned().into()); @@ -884,8 +771,7 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id))] - pub fn get_to_device_events( + fn get_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, @@ -907,8 +793,7 @@ impl Users { Ok(events) } - #[tracing::instrument(skip(self, user_id, device_id, until))] - pub fn remove_to_device_events( + fn remove_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, @@ -929,7 +814,7 @@ impl Users { .map(|(key, _)| { Ok::<_, Error>(( key.clone(), - utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()]) + utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()]) .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, )) }) @@ -942,8 +827,7 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id, device))] - pub fn update_device_metadata( + fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, @@ -968,8 +852,7 @@ impl Users { } /// Get device metadata. - #[tracing::instrument(skip(self, user_id, device_id))] - pub fn get_device_metadata( + fn get_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, @@ -987,8 +870,7 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] - pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { + fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { self.userid_devicelistversion .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -998,46 +880,26 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] - pub fn all_devices_metadata<'a>( + fn all_devices_metadata<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator<Item = Result<Device>> + 'a { + ) -> Box<dyn Iterator<Item = Result<Device>> + 'a> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::<Device>(&bytes) - .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) - }) - } - - /// Deactivate account - #[tracing::instrument(skip(self, user_id))] - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { - // Remove all associated devices - for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; - } - - // Set the password to "" to indicate a deactivated account. Hashes will never result in an - // empty string, so the user will not be able to log in again. Systems like changing the - // password without logging in should check if the account is deactivated. - self.userid_password.insert(user_id.as_bytes(), &[])?; - - // TODO: Unhook 3PID - Ok(()) + Box::new( + self.userdeviceid_metadata + .scan_prefix(key) + .map(|(_, bytes)| { + serde_json::from_slice::<Device>(&bytes).map_err(|_| { + Error::bad_database("Device in userdeviceid_metadata is invalid.") + }) + }), + ) } /// Creates a new sync filter. Returns the filter id. - #[tracing::instrument(skip(self))] - pub fn create_filter( - &self, - user_id: &UserId, - filter: &IncomingFilterDefinition, - ) -> Result<String> { + fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> { let filter_id = utils::random_string(4); let mut key = user_id.as_bytes().to_vec(); @@ -1052,12 +914,7 @@ impl Users { Ok(filter_id) } - #[tracing::instrument(skip(self))] - pub fn get_filter( - &self, - user_id: &UserId, - filter_id: &str, - ) -> Result<Option<IncomingFilterDefinition>> { + fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(filter_id.as_bytes()); @@ -1073,29 +930,24 @@ impl Users { } } -/// Ensure that a user only sees signatures from themselves and the target user -fn clean_signatures<F: Fn(&UserId) -> bool>( - cross_signing_key: &mut serde_json::Value, - user_id: &UserId, - allowed_signatures: F, -) -> Result<(), Error> { - if let Some(signatures) = cross_signing_key - .get_mut("signatures") - .and_then(|v| v.as_object_mut()) - { - // Don't allocate for the full size of the current signatures, but require - // at most one resize if nothing is dropped - let new_capacity = signatures.len() / 2; - for (user, signature) in - mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) - { - let id = <&UserId>::try_from(user.as_str()) - .map_err(|_| Error::bad_database("Invalid user ID in database."))?; - if id == user_id || allowed_signatures(id) { - signatures.insert(user, signature); +/// Will only return with Some(username) if the password was not empty and the +/// username could be successfully parsed. +/// If utils::string_from_bytes(...) returns an error that username will be skipped +/// and the error will be logged. +fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> { + // A valid password is not empty + if password.is_empty() { + None + } else { + match utils::string_from_bytes(username) { + Ok(u) => Some(u), + Err(e) => { + warn!( + "Failed to parse username while calling get_local_users(): {}", + e.to_string() + ); + None } } } - - Ok(()) } diff --git a/src/database/media.rs b/src/database/media.rs deleted file mode 100644 index a4bb402..0000000 --- a/src/database/media.rs +++ /dev/null @@ -1,358 +0,0 @@ -use crate::database::globals::Globals; -use image::{imageops::FilterType, GenericImageView}; - -use super::abstraction::Tree; -use crate::{utils, Error, Result}; -use std::{mem, sync::Arc}; -use tokio::{ - fs::File, - io::{AsyncReadExt, AsyncWriteExt}, -}; - -pub struct FileMeta { - pub content_disposition: Option<String>, - pub content_type: Option<String>, - pub file: Vec<u8>, -} - -pub struct Media { - pub(super) mediaid_file: Arc<dyn Tree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType -} - -impl Media { - /// Uploads a file. - pub async fn create( - &self, - mxc: String, - globals: &Globals, - content_disposition: &Option<&str>, - content_type: &Option<&str>, - file: &[u8], - ) -> Result<()> { - let mut key = mxc.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail - key.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail - key.push(0xff); - key.extend_from_slice( - content_disposition - .as_ref() - .map(|f| f.as_bytes()) - .unwrap_or_default(), - ); - key.push(0xff); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); - - let path = globals.get_media_file(&key); - let mut f = File::create(path).await?; - f.write_all(file).await?; - - self.mediaid_file.insert(&key, &[])?; - Ok(()) - } - - /// Uploads or replaces a file thumbnail. - #[allow(clippy::too_many_arguments)] - pub async fn upload_thumbnail( - &self, - mxc: String, - globals: &Globals, - content_disposition: &Option<String>, - content_type: &Option<String>, - width: u32, - height: u32, - file: &[u8], - ) -> Result<()> { - let mut key = mxc.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&width.to_be_bytes()); - key.extend_from_slice(&height.to_be_bytes()); - key.push(0xff); - key.extend_from_slice( - content_disposition - .as_ref() - .map(|f| f.as_bytes()) - .unwrap_or_default(), - ); - key.push(0xff); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); - - let path = globals.get_media_file(&key); - let mut f = File::create(path).await?; - f.write_all(file).await?; - - self.mediaid_file.insert(&key, &[])?; - - Ok(()) - } - - /// Downloads a file. - pub async fn get(&self, globals: &Globals, mxc: &str) -> Result<Option<FileMeta>> { - let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail - prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail - prefix.push(0xff); - - let first = self.mediaid_file.scan_prefix(prefix).next(); - if let Some((key, _)) = first { - let path = globals.get_media_file(&key); - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; - let mut parts = key.rsplit(|&b| b == 0xff); - - let content_type = parts - .next() - .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Content type in mediaid_file is invalid unicode.") - }) - }) - .transpose()?; - - let content_disposition_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; - - let content_disposition = if content_disposition_bytes.is_empty() { - None - } else { - Some( - utils::string_from_bytes(content_disposition_bytes).map_err(|_| { - Error::bad_database( - "Content Disposition in mediaid_file is invalid unicode.", - ) - })?, - ) - }; - - Ok(Some(FileMeta { - content_disposition, - content_type, - file, - })) - } else { - Ok(None) - } - } - - /// Returns width, height of the thumbnail and whether it should be cropped. Returns None when - /// the server should send the original file. - pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { - match (width, height) { - (0..=32, 0..=32) => Some((32, 32, true)), - (0..=96, 0..=96) => Some((96, 96, true)), - (0..=320, 0..=240) => Some((320, 240, false)), - (0..=640, 0..=480) => Some((640, 480, false)), - (0..=800, 0..=600) => Some((800, 600, false)), - _ => None, - } - } - - /// Downloads a file's thumbnail. - /// - /// Here's an example on how it works: - /// - /// - Client requests an image with width=567, height=567 - /// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails - /// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96) - /// - Server creates the thumbnail and sends it to the user - /// - /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. - pub async fn get_thumbnail( - &self, - mxc: &str, - globals: &Globals, - width: u32, - height: u32, - ) -> Result<Option<FileMeta>> { - let (width, height, crop) = self - .thumbnail_properties(width, height) - .unwrap_or((0, 0, false)); // 0, 0 because that's the original file - - let mut main_prefix = mxc.as_bytes().to_vec(); - main_prefix.push(0xff); - - let mut thumbnail_prefix = main_prefix.clone(); - thumbnail_prefix.extend_from_slice(&width.to_be_bytes()); - thumbnail_prefix.extend_from_slice(&height.to_be_bytes()); - thumbnail_prefix.push(0xff); - - let mut original_prefix = main_prefix; - original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail - original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail - original_prefix.push(0xff); - - let first_thumbnailprefix = self.mediaid_file.scan_prefix(thumbnail_prefix).next(); - let first_originalprefix = self.mediaid_file.scan_prefix(original_prefix).next(); - if let Some((key, _)) = first_thumbnailprefix { - // Using saved thumbnail - let path = globals.get_media_file(&key); - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; - let mut parts = key.rsplit(|&b| b == 0xff); - - let content_type = parts - .next() - .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Content type in mediaid_file is invalid unicode.") - }) - }) - .transpose()?; - - let content_disposition_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; - - let content_disposition = if content_disposition_bytes.is_empty() { - None - } else { - Some( - utils::string_from_bytes(content_disposition_bytes).map_err(|_| { - Error::bad_database("Content Disposition in db is invalid.") - })?, - ) - }; - - Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.to_vec(), - })) - } else if let Some((key, _)) = first_originalprefix { - // Generate a thumbnail - let path = globals.get_media_file(&key); - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; - - let mut parts = key.rsplit(|&b| b == 0xff); - - let content_type = parts - .next() - .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Content type in mediaid_file is invalid unicode.") - }) - }) - .transpose()?; - - let content_disposition_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; - - let content_disposition = if content_disposition_bytes.is_empty() { - None - } else { - Some( - utils::string_from_bytes(content_disposition_bytes).map_err(|_| { - Error::bad_database( - "Content Disposition in mediaid_file is invalid unicode.", - ) - })?, - ) - }; - - if let Ok(image) = image::load_from_memory(&file) { - let original_width = image.width(); - let original_height = image.height(); - if width > original_width || height > original_height { - return Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.to_vec(), - })); - } - - let thumbnail = if crop { - image.resize_to_fill(width, height, FilterType::CatmullRom) - } else { - let (exact_width, exact_height) = { - // Copied from image::dynimage::resize_dimensions - let ratio = u64::from(original_width) * u64::from(height); - let nratio = u64::from(width) * u64::from(original_height); - - let use_width = nratio <= ratio; - let intermediate = if use_width { - u64::from(original_height) * u64::from(width) - / u64::from(original_width) - } else { - u64::from(original_width) * u64::from(height) - / u64::from(original_height) - }; - if use_width { - if intermediate <= u64::from(::std::u32::MAX) { - (width, intermediate as u32) - } else { - ( - (u64::from(width) * u64::from(::std::u32::MAX) / intermediate) - as u32, - ::std::u32::MAX, - ) - } - } else if intermediate <= u64::from(::std::u32::MAX) { - (intermediate as u32, height) - } else { - ( - ::std::u32::MAX, - (u64::from(height) * u64::from(::std::u32::MAX) / intermediate) - as u32, - ) - } - }; - - image.thumbnail_exact(exact_width, exact_height) - }; - - let mut thumbnail_bytes = Vec::new(); - thumbnail.write_to(&mut thumbnail_bytes, image::ImageOutputFormat::Png)?; - - // Save thumbnail in database so we don't have to generate it again next time - let mut thumbnail_key = key.to_vec(); - let width_index = thumbnail_key - .iter() - .position(|&b| b == 0xff) - .ok_or_else(|| Error::bad_database("Media in db is invalid."))? - + 1; - let mut widthheight = width.to_be_bytes().to_vec(); - widthheight.extend_from_slice(&height.to_be_bytes()); - - thumbnail_key.splice( - width_index..width_index + 2 * mem::size_of::<u32>(), - widthheight, - ); - - let path = globals.get_media_file(&thumbnail_key); - let mut f = File::create(path).await?; - f.write_all(&thumbnail_bytes).await?; - - self.mediaid_file.insert(&thumbnail_key, &[])?; - - Ok(Some(FileMeta { - content_disposition, - content_type, - file: thumbnail_bytes.to_vec(), - })) - } else { - // Couldn't parse file to generate thumbnail, send original - Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.to_vec(), - })) - } - } else { - Ok(None) - } - } -} diff --git a/src/database/mod.rs b/src/database/mod.rs new file mode 100644 index 0000000..78bb358 --- /dev/null +++ b/src/database/mod.rs @@ -0,0 +1,1007 @@ +pub mod abstraction; +pub mod key_value; + +use crate::{services, utils, Config, Error, PduEvent, Result, Services, SERVICES}; +use abstraction::{KeyValueDatabaseEngine, KvTree}; +use directories::ProjectDirs; +use lru_cache::LruCache; +use ruma::{ + events::{ + push_rules::{PushRulesEvent, PushRulesEventContent}, + room::message::RoomMessageEventContent, + GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, + }, + push::Ruleset, + CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, + UserId, +}; +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + fs::{self, remove_dir_all}, + io::Write, + mem::size_of, + path::Path, + sync::{Arc, Mutex, RwLock}, +}; + +use tracing::{debug, error, info, warn}; + +pub struct KeyValueDatabase { + _db: Arc<dyn KeyValueDatabaseEngine>, + + //pub globals: globals::Globals, + pub(super) global: Arc<dyn KvTree>, + pub(super) server_signingkeys: Arc<dyn KvTree>, + + //pub users: users::Users, + pub(super) userid_password: Arc<dyn KvTree>, + pub(super) userid_displayname: Arc<dyn KvTree>, + pub(super) userid_avatarurl: Arc<dyn KvTree>, + pub(super) userid_blurhash: Arc<dyn KvTree>, + pub(super) userdeviceid_token: Arc<dyn KvTree>, + pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists + pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64 + pub(super) token_userdeviceid: Arc<dyn KvTree>, + + pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId + pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count + pub(super) keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count + pub(super) keyid_key: Arc<dyn KvTree>, // KeyId = UserId + KeyId (depends on key type) + pub(super) userid_masterkeyid: Arc<dyn KvTree>, + pub(super) userid_selfsigningkeyid: Arc<dyn KvTree>, + pub(super) userid_usersigningkeyid: Arc<dyn KvTree>, + + pub(super) userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId + + pub(super) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count + + //pub uiaa: uiaa::Uiaa, + pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication + pub(super) userdevicesessionid_uiaarequest: + RwLock<BTreeMap<(OwnedUserId, OwnedDeviceId, String), CanonicalJsonValue>>, + + //pub edus: RoomEdus, + pub(super) readreceiptid_readreceipt: Arc<dyn KvTree>, // ReadReceiptId = RoomId + Count + UserId + pub(super) roomuserid_privateread: Arc<dyn KvTree>, // RoomUserId = Room + User, PrivateRead = Count + pub(super) roomuserid_lastprivatereadupdate: Arc<dyn KvTree>, // LastPrivateReadUpdate = Count + pub(super) typingid_userid: Arc<dyn KvTree>, // TypingId = RoomId + TimeoutTime + Count + pub(super) roomid_lasttypingupdate: Arc<dyn KvTree>, // LastRoomTypingUpdate = Count + pub(super) presenceid_presence: Arc<dyn KvTree>, // PresenceId = RoomId + Count + UserId + pub(super) userid_lastpresenceupdate: Arc<dyn KvTree>, // LastPresenceUpdate = Count + + //pub rooms: rooms::Rooms, + pub(super) pduid_pdu: Arc<dyn KvTree>, // PduId = ShortRoomId + Count + pub(super) eventid_pduid: Arc<dyn KvTree>, + pub(super) roomid_pduleaves: Arc<dyn KvTree>, + pub(super) alias_roomid: Arc<dyn KvTree>, + pub(super) aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count + pub(super) publicroomids: Arc<dyn KvTree>, + + pub(super) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount + + /// Participating servers in a room. + pub(super) roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName + pub(super) serverroomids: Arc<dyn KvTree>, // ServerRoomId = ServerName + RoomId + + pub(super) userroomid_joined: Arc<dyn KvTree>, + pub(super) roomuserid_joined: Arc<dyn KvTree>, + pub(super) roomid_joinedcount: Arc<dyn KvTree>, + pub(super) roomid_invitedcount: Arc<dyn KvTree>, + pub(super) roomuseroncejoinedids: Arc<dyn KvTree>, + pub(super) userroomid_invitestate: Arc<dyn KvTree>, // InviteState = Vec<Raw<Pdu>> + pub(super) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count + pub(super) userroomid_leftstate: Arc<dyn KvTree>, + pub(super) roomuserid_leftcount: Arc<dyn KvTree>, + + pub(super) disabledroomids: Arc<dyn KvTree>, // Rooms where incoming federation handling is disabled + + pub(super) lazyloadedids: Arc<dyn KvTree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId + + pub(super) userroomid_notificationcount: Arc<dyn KvTree>, // NotifyCount = u64 + pub(super) userroomid_highlightcount: Arc<dyn KvTree>, // HightlightCount = u64 + pub(super) roomuserid_lastnotificationread: Arc<dyn KvTree>, // LastNotificationRead = u64 + + /// Remember the current state hash of a room. + pub(super) roomid_shortstatehash: Arc<dyn KvTree>, + pub(super) roomsynctoken_shortstatehash: Arc<dyn KvTree>, + /// Remember the state hash at events in the past. + pub(super) shorteventid_shortstatehash: Arc<dyn KvTree>, + /// StateKey = EventType + StateKey, ShortStateKey = Count + pub(super) statekey_shortstatekey: Arc<dyn KvTree>, + pub(super) shortstatekey_statekey: Arc<dyn KvTree>, + + pub(super) roomid_shortroomid: Arc<dyn KvTree>, + + pub(super) shorteventid_eventid: Arc<dyn KvTree>, + pub(super) eventid_shorteventid: Arc<dyn KvTree>, + + pub(super) statehash_shortstatehash: Arc<dyn KvTree>, + pub(super) shortstatehash_statediff: Arc<dyn KvTree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) + + pub(super) shorteventid_authchain: Arc<dyn KvTree>, + + /// RoomId + EventId -> outlier PDU. + /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. + pub(super) eventid_outlierpdu: Arc<dyn KvTree>, + pub(super) softfailedeventids: Arc<dyn KvTree>, + + /// RoomId + EventId -> Parent PDU EventId. + pub(super) referencedevents: Arc<dyn KvTree>, + + //pub account_data: account_data::AccountData, + pub(super) roomuserdataid_accountdata: Arc<dyn KvTree>, // RoomUserDataId = Room + User + Count + Type + pub(super) roomusertype_roomuserdataid: Arc<dyn KvTree>, // RoomUserType = Room + User + Type + + //pub media: media::Media, + pub(super) mediaid_file: Arc<dyn KvTree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType + //pub key_backups: key_backups::KeyBackups, + pub(super) backupid_algorithm: Arc<dyn KvTree>, // BackupId = UserId + Version(Count) + pub(super) backupid_etag: Arc<dyn KvTree>, // BackupId = UserId + Version(Count) + pub(super) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId + + //pub transaction_ids: transaction_ids::TransactionIds, + pub(super) userdevicetxnid_response: Arc<dyn KvTree>, // Response can be empty (/sendToDevice) or the event id (/send) + //pub sending: sending::Sending, + pub(super) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync + pub(super) servernameevent_data: Arc<dyn KvTree>, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content + pub(super) servercurrentevent_data: Arc<dyn KvTree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content + + //pub appservice: appservice::Appservice, + pub(super) id_appserviceregistrations: Arc<dyn KvTree>, + + //pub pusher: pusher::PushData, + pub(super) senderkey_pusher: Arc<dyn KvTree>, + + pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>, + pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>, + pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>, + pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, + pub(super) eventidshort_cache: Mutex<LruCache<OwnedEventId, u64>>, + pub(super) statekeyshort_cache: Mutex<LruCache<(StateEventType, String), u64>>, + pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>, + pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>, + pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>, + pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, u64>>, +} + +impl KeyValueDatabase { + /// Tries to remove the old database but ignores all errors. + pub fn try_remove(server_name: &str) -> Result<()> { + let mut path = ProjectDirs::from("xyz", "koesters", "conduit") + .ok_or_else(|| Error::bad_config("The OS didn't return a valid home directory path."))? + .data_dir() + .to_path_buf(); + path.push(server_name); + let _ = remove_dir_all(path); + + Ok(()) + } + + fn check_db_setup(config: &Config) -> Result<()> { + let path = Path::new(&config.database_path); + + let sled_exists = path.join("db").exists(); + let sqlite_exists = path.join("conduit.db").exists(); + let rocksdb_exists = path.join("IDENTITY").exists(); + + let mut count = 0; + + if sled_exists { + count += 1; + } + + if sqlite_exists { + count += 1; + } + + if rocksdb_exists { + count += 1; + } + + if count > 1 { + warn!("Multiple databases at database_path detected"); + return Ok(()); + } + + if sled_exists && config.database_backend != "sled" { + return Err(Error::bad_config( + "Found sled at database_path, but is not specified in config.", + )); + } + + if sqlite_exists && config.database_backend != "sqlite" { + return Err(Error::bad_config( + "Found sqlite at database_path, but is not specified in config.", + )); + } + + if rocksdb_exists && config.database_backend != "rocksdb" { + return Err(Error::bad_config( + "Found rocksdb at database_path, but is not specified in config.", + )); + } + + Ok(()) + } + + /// Load an existing database or create a new one. + pub async fn load_or_create(config: Config) -> Result<()> { + Self::check_db_setup(&config)?; + + if !Path::new(&config.database_path).exists() { + std::fs::create_dir_all(&config.database_path) + .map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?; + } + + let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend { + "sqlite" => { + #[cfg(not(feature = "sqlite"))] + return Err(Error::BadConfig("Database backend not found.")); + #[cfg(feature = "sqlite")] + Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?) + } + "rocksdb" => { + #[cfg(not(feature = "rocksdb"))] + return Err(Error::BadConfig("Database backend not found.")); + #[cfg(feature = "rocksdb")] + Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?) + } + "persy" => { + #[cfg(not(feature = "persy"))] + return Err(Error::BadConfig("Database backend not found.")); + #[cfg(feature = "persy")] + Arc::new(Arc::<abstraction::persy::Engine>::open(&config)?) + } + _ => { + return Err(Error::BadConfig("Database backend not found.")); + } + }; + + if config.max_request_size < 1024 { + error!(?config.max_request_size, "Max request size is less than 1KB. Please increase it."); + } + + let db_raw = Box::new(Self { + _db: builder.clone(), + userid_password: builder.open_tree("userid_password")?, + userid_displayname: builder.open_tree("userid_displayname")?, + userid_avatarurl: builder.open_tree("userid_avatarurl")?, + userid_blurhash: builder.open_tree("userid_blurhash")?, + userdeviceid_token: builder.open_tree("userdeviceid_token")?, + userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, + userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, + token_userdeviceid: builder.open_tree("token_userdeviceid")?, + onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, + userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, + keychangeid_userid: builder.open_tree("keychangeid_userid")?, + keyid_key: builder.open_tree("keyid_key")?, + userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, + userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, + userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, + userfilterid_filter: builder.open_tree("userfilterid_filter")?, + todeviceid_events: builder.open_tree("todeviceid_events")?, + + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, + userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), + readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, + roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt + roomuserid_lastprivatereadupdate: builder + .open_tree("roomuserid_lastprivatereadupdate")?, + typingid_userid: builder.open_tree("typingid_userid")?, + roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, + presenceid_presence: builder.open_tree("presenceid_presence")?, + userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, + pduid_pdu: builder.open_tree("pduid_pdu")?, + eventid_pduid: builder.open_tree("eventid_pduid")?, + roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, + + alias_roomid: builder.open_tree("alias_roomid")?, + aliasid_alias: builder.open_tree("aliasid_alias")?, + publicroomids: builder.open_tree("publicroomids")?, + + tokenids: builder.open_tree("tokenids")?, + + roomserverids: builder.open_tree("roomserverids")?, + serverroomids: builder.open_tree("serverroomids")?, + userroomid_joined: builder.open_tree("userroomid_joined")?, + roomuserid_joined: builder.open_tree("roomuserid_joined")?, + roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, + roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, + roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, + userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, + roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, + userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, + roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + + disabledroomids: builder.open_tree("disabledroomids")?, + + lazyloadedids: builder.open_tree("lazyloadedids")?, + + userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, + userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, + roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, + + statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, + shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, + + shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, + + roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, + + shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, + eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, + shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, + shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, + roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, + roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, + statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, + + eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, + softfailedeventids: builder.open_tree("softfailedeventids")?, + + referencedevents: builder.open_tree("referencedevents")?, + roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, + roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, + mediaid_file: builder.open_tree("mediaid_file")?, + backupid_algorithm: builder.open_tree("backupid_algorithm")?, + backupid_etag: builder.open_tree("backupid_etag")?, + backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, + userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, + servername_educount: builder.open_tree("servername_educount")?, + servernameevent_data: builder.open_tree("servernameevent_data")?, + servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, + id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, + senderkey_pusher: builder.open_tree("senderkey_pusher")?, + global: builder.open_tree("global")?, + server_signingkeys: builder.open_tree("server_signingkeys")?, + + cached_registrations: Arc::new(RwLock::new(HashMap::new())), + pdu_cache: Mutex::new(LruCache::new( + config + .pdu_cache_capacity + .try_into() + .expect("pdu cache capacity fits into usize"), + )), + auth_chain_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + shorteventid_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + eventidshort_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + shortstatekey_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + statekeyshort_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + our_real_users_cache: RwLock::new(HashMap::new()), + appservice_in_room_cache: RwLock::new(HashMap::new()), + lasttimelinecount_cache: Mutex::new(HashMap::new()), + }); + + let db = Box::leak(db_raw); + + let services_raw = Box::new(Services::build(db, config)?); + + // This is the first and only time we initialize the SERVICE static + *SERVICES.write().unwrap() = Some(Box::leak(services_raw)); + + // Matrix resource ownership is based on the server name; changing it + // requires recreating the database from scratch. + if services().users.count()? > 0 { + let conduit_user = + UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); + + if !services().users.exists(&conduit_user)? { + error!( + "The {} server user does not exist, and the database is not new.", + conduit_user + ); + return Err(Error::bad_database( + "Cannot reuse an existing database after changing the server name, please delete the old one first." + )); + } + } + + // If the database has any data, perform data migrations before starting + let latest_database_version = 12; + + if services().users.count()? > 0 { + // MIGRATIONS + if services().globals.database_version()? < 1 { + for (roomserverid, _) in db.roomserverids.iter() { + let mut parts = roomserverid.split(|&b| b == 0xff); + let room_id = parts.next().expect("split always returns one element"); + let servername = match parts.next() { + Some(s) => s, + None => { + error!("Migration: Invalid roomserverid in db."); + continue; + } + }; + let mut serverroomid = servername.to_vec(); + serverroomid.push(0xff); + serverroomid.extend_from_slice(room_id); + + db.serverroomids.insert(&serverroomid, &[])?; + } + + services().globals.bump_database_version(1)?; + + warn!("Migration: 0 -> 1 finished"); + } + + if services().globals.database_version()? < 2 { + // We accidentally inserted hashed versions of "" into the db instead of just "" + for (userid, password) in db.userid_password.iter() { + let password = utils::string_from_bytes(&password); + + let empty_hashed_password = password.map_or(false, |password| { + argon2::verify_encoded(&password, b"").unwrap_or(false) + }); + + if empty_hashed_password { + db.userid_password.insert(&userid, b"")?; + } + } + + services().globals.bump_database_version(2)?; + + warn!("Migration: 1 -> 2 finished"); + } + + if services().globals.database_version()? < 3 { + // Move media to filesystem + for (key, content) in db.mediaid_file.iter() { + if content.is_empty() { + continue; + } + + let path = services().globals.get_media_file(&key); + let mut file = fs::File::create(path)?; + file.write_all(&content)?; + db.mediaid_file.insert(&key, &[])?; + } + + services().globals.bump_database_version(3)?; + + warn!("Migration: 2 -> 3 finished"); + } + + if services().globals.database_version()? < 4 { + // Add federated users to services() as deactivated + for our_user in services().users.iter() { + let our_user = our_user?; + if services().users.is_deactivated(&our_user)? { + continue; + } + for room in services().rooms.state_cache.rooms_joined(&our_user) { + for user in services().rooms.state_cache.room_members(&room?) { + let user = user?; + if user.server_name() != services().globals.server_name() { + info!(?user, "Migration: creating user"); + services().users.create(&user, None)?; + } + } + } + } + + services().globals.bump_database_version(4)?; + + warn!("Migration: 3 -> 4 finished"); + } + + if services().globals.database_version()? < 5 { + // Upgrade user data store + for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() { + let mut parts = roomuserdataid.split(|&b| b == 0xff); + let room_id = parts.next().unwrap(); + let user_id = parts.next().unwrap(); + let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap(); + + let mut key = room_id.to_vec(); + key.push(0xff); + key.extend_from_slice(user_id); + key.push(0xff); + key.extend_from_slice(event_type); + + db.roomusertype_roomuserdataid + .insert(&key, &roomuserdataid)?; + } + + services().globals.bump_database_version(5)?; + + warn!("Migration: 4 -> 5 finished"); + } + + if services().globals.database_version()? < 6 { + // Set room member count + for (roomid, _) in db.roomid_shortstatehash.iter() { + let string = utils::string_from_bytes(&roomid).unwrap(); + let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); + services().rooms.state_cache.update_joined_count(room_id)?; + } + + services().globals.bump_database_version(6)?; + + warn!("Migration: 5 -> 6 finished"); + } + + if services().globals.database_version()? < 7 { + // Upgrade state store + let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new(); + let mut current_sstatehash: Option<u64> = None; + let mut current_room = None; + let mut current_state = HashSet::new(); + let mut counter = 0; + + let mut handle_state = + |current_sstatehash: u64, + current_room: &RoomId, + current_state: HashSet<_>, + last_roomstates: &mut HashMap<_, _>| { + counter += 1; + let last_roomsstatehash = last_roomstates.get(current_room); + + let states_parents = last_roomsstatehash.map_or_else( + || Ok(Vec::new()), + |&last_roomsstatehash| { + services() + .rooms + .state_compressor + .load_shortstatehash_info(last_roomsstatehash) + }, + )?; + + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = current_state + .difference(&parent_stateinfo.1) + .copied() + .collect::<HashSet<_>>(); + + let statediffremoved = parent_stateinfo + .1 + .difference(¤t_state) + .copied() + .collect::<HashSet<_>>(); + + (statediffnew, statediffremoved) + } else { + (current_state, HashSet::new()) + }; + + services().rooms.state_compressor.save_state_from_diff( + current_sstatehash, + statediffnew, + statediffremoved, + 2, // every state change is 2 event changes on average + states_parents, + )?; + + /* + let mut tmp = services().rooms.load_shortstatehash_info(¤t_sstatehash)?; + let state = tmp.pop().unwrap(); + println!( + "{}\t{}{:?}: {:?} + {:?} - {:?}", + current_room, + " ".repeat(tmp.len()), + utils::u64_from_bytes(¤t_sstatehash).unwrap(), + tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), + state + .2 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) + .collect::<Vec<_>>(), + state + .3 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) + .collect::<Vec<_>>() + ); + */ + + Ok::<_, Error>(()) + }; + + for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { + let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]) + .expect("number of bytes is correct"); + let sstatekey = k[size_of::<u64>()..].to_vec(); + if Some(sstatehash) != current_sstatehash { + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_deref().unwrap(), + current_state, + &mut last_roomstates, + )?; + last_roomstates + .insert(current_room.clone().unwrap(), current_sstatehash); + } + current_state = HashSet::new(); + current_sstatehash = Some(sstatehash); + + let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap(); + let string = utils::string_from_bytes(&event_id).unwrap(); + let event_id = <&EventId>::try_from(string.as_str()).unwrap(); + let pdu = services() + .rooms + .timeline + .get_pdu(event_id) + .unwrap() + .unwrap(); + + if Some(&pdu.room_id) != current_room.as_ref() { + current_room = Some(pdu.room_id.clone()); + } + } + + let mut val = sstatekey; + val.extend_from_slice(&seventid); + current_state.insert(val.try_into().expect("size is correct")); + } + + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_deref().unwrap(), + current_state, + &mut last_roomstates, + )?; + } + + services().globals.bump_database_version(7)?; + + warn!("Migration: 6 -> 7 finished"); + } + + if services().globals.database_version()? < 8 { + // Generate short room ids for all rooms + for (room_id, _) in db.roomid_shortstatehash.iter() { + let shortroomid = services().globals.next_count()?.to_be_bytes(); + db.roomid_shortroomid.insert(&room_id, &shortroomid)?; + info!("Migration: 8"); + } + // Update pduids db layout + let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| { + if !key.starts_with(b"!") { + return None; + } + let mut parts = key.splitn(2, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db + .roomid_shortroomid + .get(room_id) + .unwrap() + .expect("shortroomid should exist"); + + let mut new_key = short_room_id; + new_key.extend_from_slice(count); + + Some((new_key, v)) + }); + + db.pduid_pdu.insert_batch(&mut batch)?; + + let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| { + if !value.starts_with(b"!") { + return None; + } + let mut parts = value.splitn(2, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db + .roomid_shortroomid + .get(room_id) + .unwrap() + .expect("shortroomid should exist"); + + let mut new_value = short_room_id; + new_value.extend_from_slice(count); + + Some((k, new_value)) + }); + + db.eventid_pduid.insert_batch(&mut batch2)?; + + services().globals.bump_database_version(8)?; + + warn!("Migration: 7 -> 8 finished"); + } + + if services().globals.database_version()? < 9 { + // Update tokenids db layout + let mut iter = db + .tokenids + .iter() + .filter_map(|(key, _)| { + if !key.starts_with(b"!") { + return None; + } + let mut parts = key.splitn(4, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let word = parts.next().unwrap(); + let _pdu_id_room = parts.next().unwrap(); + let pdu_id_count = parts.next().unwrap(); + + let short_room_id = db + .roomid_shortroomid + .get(room_id) + .unwrap() + .expect("shortroomid should exist"); + let mut new_key = short_room_id; + new_key.extend_from_slice(word); + new_key.push(0xff); + new_key.extend_from_slice(pdu_id_count); + Some((new_key, Vec::new())) + }) + .peekable(); + + while iter.peek().is_some() { + db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?; + debug!("Inserted smaller batch"); + } + + info!("Deleting starts"); + + let batch2: Vec<_> = db + .tokenids + .iter() + .filter_map(|(key, _)| { + if key.starts_with(b"!") { + Some(key) + } else { + None + } + }) + .collect(); + + for key in batch2 { + db.tokenids.remove(&key)?; + } + + services().globals.bump_database_version(9)?; + + warn!("Migration: 8 -> 9 finished"); + } + + if services().globals.database_version()? < 10 { + // Add other direction for shortstatekeys + for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { + db.shortstatekey_statekey + .insert(&shortstatekey, &statekey)?; + } + + // Force E2EE device list updates so we can send them over federation + for user_id in services().users.iter().filter_map(|r| r.ok()) { + services().users.mark_device_key_update(&user_id)?; + } + + services().globals.bump_database_version(10)?; + + warn!("Migration: 9 -> 10 finished"); + } + + if services().globals.database_version()? < 11 { + db._db + .open_tree("userdevicesessionid_uiaarequest")? + .clear()?; + services().globals.bump_database_version(11)?; + + warn!("Migration: 10 -> 11 finished"); + } + + if services().globals.database_version()? < 12 { + for username in services().users.list_local_users().unwrap() { + let user = + UserId::parse_with_server_name(username, services().globals.server_name()) + .unwrap(); + + let raw_rules_list = services() + .account_data + .get( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + ) + .unwrap() + .expect("Username is invalid"); + + let mut account_data = + serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap(); + let rules_list = &mut account_data.content.global; + + //content rule + { + let content_rule_transformation = + [".m.rules.contains_user_name", ".m.rule.contains_user_name"]; + + let rule = rules_list.content.get(content_rule_transformation[0]); + if rule.is_some() { + let mut rule = rule.unwrap().clone(); + rule.rule_id = content_rule_transformation[1].to_owned(); + rules_list.content.remove(content_rule_transformation[0]); + rules_list.content.insert(rule); + } + } + + //underride rules + { + let underride_rule_transformation = [ + [".m.rules.call", ".m.rule.call"], + [".m.rules.room_one_to_one", ".m.rule.room_one_to_one"], + [ + ".m.rules.encrypted_room_one_to_one", + ".m.rule.encrypted_room_one_to_one", + ], + [".m.rules.message", ".m.rule.message"], + [".m.rules.encrypted", ".m.rule.encrypted"], + ]; + + for transformation in underride_rule_transformation { + let rule = rules_list.underride.get(transformation[0]); + if let Some(rule) = rule { + let mut rule = rule.clone(); + rule.rule_id = transformation[1].to_owned(); + rules_list.underride.remove(transformation[0]); + rules_list.underride.insert(rule); + } + } + } + + services().account_data.update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + )?; + } + + services().globals.bump_database_version(12)?; + + warn!("Migration: 11 -> 12 finished"); + } + + assert_eq!( + services().globals.database_version().unwrap(), + latest_database_version + ); + + info!( + "Loaded {} database with version {}", + services().globals.config.database_backend, + latest_database_version + ); + } else { + services() + .globals + .bump_database_version(latest_database_version)?; + + // Create the admin room and server user on first run + services().admin.create_admin_room().await?; + + warn!( + "Created new {} database with version {}", + services().globals.config.database_backend, + latest_database_version + ); + } + + // This data is probably outdated + db.presenceid_presence.clear()?; + + services().admin.start_handler(); + + // Set emergency access for the conduit user + match set_emergency_access() { + Ok(pwd_set) => { + if pwd_set { + warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"); + services().admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!")); + } + } + Err(e) => { + error!( + "Could not set the configured emergency password for the conduit user: {}", + e + ) + } + }; + + services().sending.start_handler(); + + Self::start_cleanup_task().await; + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub fn flush(&self) -> Result<()> { + let start = std::time::Instant::now(); + + let res = self._db.flush(); + + debug!("flush: took {:?}", start.elapsed()); + + res + } + + #[tracing::instrument] + pub async fn start_cleanup_task() { + use tokio::time::interval; + + #[cfg(unix)] + use tokio::signal::unix::{signal, SignalKind}; + + use std::time::{Duration, Instant}; + + let timer_interval = + Duration::from_secs(services().globals.config.cleanup_second_interval as u64); + + tokio::spawn(async move { + let mut i = interval(timer_interval); + #[cfg(unix)] + let mut s = signal(SignalKind::hangup()).unwrap(); + + loop { + #[cfg(unix)] + tokio::select! { + _ = i.tick() => { + debug!("cleanup: Timer ticked"); + } + _ = s.recv() => { + debug!("cleanup: Received SIGHUP"); + } + }; + #[cfg(not(unix))] + { + i.tick().await; + debug!("cleanup: Timer ticked") + } + + let start = Instant::now(); + if let Err(e) = services().globals.cleanup() { + error!("cleanup: Errored: {}", e); + } else { + debug!("cleanup: Finished in {:?}", start.elapsed()); + } + } + }); + } +} + +/// Sets the emergency password and push rules for the @conduit account in case emergency password is set +fn set_emergency_access() -> Result<bool> { + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is a valid UserId"); + + services().users.set_password( + &conduit_user, + services().globals.emergency_password().as_deref(), + )?; + + let (ruleset, res) = match services().globals.emergency_password() { + Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)), + None => (Ruleset::new(), Ok(false)), + }; + + services().account_data.update( + None, + &conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { global: ruleset }, + }) + .expect("to json value always works"), + )?; + + res +} diff --git a/src/database/pusher.rs b/src/database/pusher.rs deleted file mode 100644 index 6b906c2..0000000 --- a/src/database/pusher.rs +++ /dev/null @@ -1,348 +0,0 @@ -use crate::{Database, Error, PduEvent, Result}; -use bytes::BytesMut; -use ruma::{ - api::{ - client::push::{get_pushers, set_pusher, PusherKind}, - push_gateway::send_event_notification::{ - self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, - }, - IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, - }, - events::{ - room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, - AnySyncRoomEvent, RoomEventType, StateEventType, - }, - push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, - serde::Raw, - uint, RoomId, UInt, UserId, -}; -use tracing::{error, info, warn}; - -use std::{fmt::Debug, mem, sync::Arc}; - -use super::abstraction::Tree; - -pub struct PushData { - /// UserId + pushkey -> Pusher - pub(super) senderkey_pusher: Arc<dyn Tree>, -} - -impl PushData { - #[tracing::instrument(skip(self, sender, pusher))] - pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { - let mut key = sender.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(pusher.pushkey.as_bytes()); - - // There are 2 kinds of pushers but the spec says: null deletes the pusher. - if pusher.kind.is_none() { - return self - .senderkey_pusher - .remove(&key) - .map(|_| ()) - .map_err(Into::into); - } - - self.senderkey_pusher.insert( - &key, - &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), - )?; - - Ok(()) - } - - #[tracing::instrument(skip(self, senderkey))] - pub fn get_pusher(&self, senderkey: &[u8]) -> Result<Option<get_pushers::v3::Pusher>> { - self.senderkey_pusher - .get(senderkey)? - .map(|push| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .transpose() - } - - #[tracing::instrument(skip(self, sender))] - pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::v3::Pusher>> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .collect() - } - - #[tracing::instrument(skip(self, sender))] - pub fn get_pusher_senderkeys<'a>( - &'a self, - sender: &UserId, - ) -> impl Iterator<Item = Vec<u8>> + 'a { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) - } -} - -#[tracing::instrument(skip(globals, destination, request))] -pub async fn send_request<T: OutgoingRequest>( - globals: &crate::database::globals::Globals, - destination: &str, - request: T, -) -> Result<T::IncomingResponse> -where - T: Debug, -{ - let destination = destination.replace("/_matrix/push/v1/notify", ""); - - let http_request = request - .try_into_http_request::<BytesMut>( - &destination, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - })? - .map(|body| body.freeze()); - - let reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); - - // TODO: we could keep this very short and let expo backoff do it's thing... - //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); - - let url = reqwest_request.url().clone(); - let response = globals.default_client().execute(reqwest_request).await; - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { - info!( - "Push gateway returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - crate::utils::string_from_bytes(&body) - ); - } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - info!( - "Push gateway returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Push gateway returned bad response.") - }) - } - Err(e) => Err(e.into()), - } -} - -#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] -pub async fn send_push_notice( - user: &UserId, - unread: UInt, - pusher: &get_pushers::v3::Pusher, - ruleset: Ruleset, - pdu: &PduEvent, - db: &Database, -) -> Result<()> { - let mut notify = None; - let mut tweaks = Vec::new(); - - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - for action in get_actions( - user, - &ruleset, - &power_levels, - &pdu.to_sync_room_event(), - &pdu.room_id, - db, - )? { - let n = match action { - Action::DontNotify => false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => true, - Action::SetTweak(tweak) => { - tweaks.push(tweak.clone()); - continue; - } - }; - - if notify.is_some() { - return Err(Error::bad_database( - r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, - )); - } - - notify = Some(n); - } - - if notify == Some(true) { - send_notice(unread, pusher, tweaks, pdu, db).await?; - } - // Else the event triggered no actions - - Ok(()) -} - -#[tracing::instrument(skip(user, ruleset, pdu, db))] -pub fn get_actions<'a>( - user: &UserId, - ruleset: &'a Ruleset, - power_levels: &RoomPowerLevelsEventContent, - pdu: &Raw<AnySyncRoomEvent>, - room_id: &RoomId, - db: &Database, -) -> Result<&'a [Action]> { - let ctx = PushConditionRoomCtx { - room_id: room_id.to_owned(), - member_count: 10_u32.into(), // TODO: get member count efficiently - user_display_name: db - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), - users_power_levels: power_levels.users.clone(), - default_power_level: power_levels.users_default, - notification_power_levels: power_levels.notifications.clone(), - }; - - Ok(ruleset.get_actions(pdu, &ctx)) -} - -#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] -async fn send_notice( - unread: UInt, - pusher: &get_pushers::v3::Pusher, - tweaks: Vec<Tweak>, - event: &PduEvent, - db: &Database, -) -> Result<()> { - // TODO: email - if pusher.kind == PusherKind::Email { - return Ok(()); - } - - // TODO: - // Two problems with this - // 1. if "event_id_only" is the only format kind it seems we should never add more info - // 2. can pusher/devices have conflicting formats - let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); - let url = if let Some(url) = &pusher.data.url { - url - } else { - error!("Http Pusher must have URL specified."); - return Ok(()); - }; - - let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); - let mut data_minus_url = pusher.data.clone(); - // The url must be stripped off according to spec - data_minus_url.url = None; - device.data = data_minus_url; - - // Tweaks are only added if the format is NOT event_id_only - if !event_id_only { - device.tweaks = tweaks.clone(); - } - - let d = &[device]; - let mut notifi = Notification::new(d); - - notifi.prio = NotificationPriority::Low; - notifi.event_id = Some(&event.event_id); - notifi.room_id = Some(&event.room_id); - // TODO: missed calls - notifi.counts = NotificationCounts::new(unread, uint!(0)); - - if event.kind == RoomEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) - { - notifi.prio = NotificationPriority::High - } - - if event_id_only { - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } else { - notifi.sender = Some(&event.sender); - notifi.event_type = Some(&event.kind); - let content = serde_json::value::to_raw_value(&event.content).ok(); - notifi.content = content.as_deref(); - - if event.kind == RoomEventType::RoomMember { - notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); - } - - let user_name = db.users.displayname(&event.sender)?; - notifi.sender_display_name = user_name.as_deref(); - - let room_name = if let Some(room_name_pdu) = - db.rooms - .room_state_get(&event.room_id, &StateEventType::RoomName, "")? - { - serde_json::from_str::<RoomNameEventContent>(room_name_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid room name event in database."))? - .name - } else { - None - }; - - notifi.room_name = room_name.as_deref(); - - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } - - // TODO: email - - Ok(()) -} diff --git a/src/database/rooms.rs b/src/database/rooms.rs deleted file mode 100644 index 4ad815e..0000000 --- a/src/database/rooms.rs +++ /dev/null @@ -1,3503 +0,0 @@ -mod edus; - -pub use edus::RoomEdus; - -use crate::{ - pdu::{EventHash, PduBuilder}, - utils, Database, Error, PduEvent, Result, -}; -use lru_cache::LruCache; -use regex::Regex; -use ring::digest; -use ruma::{ - api::{client::error::ErrorKind, federation}, - events::{ - direct::DirectEvent, - ignored_user_list::IgnoredUserListEvent, - push_rules::PushRulesEvent, - room::{ - create::RoomCreateEventContent, - member::{MembershipState, RoomMemberEventContent}, - power_levels::RoomPowerLevelsEventContent, - }, - tag::TagEvent, - AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, - RoomAccountDataEventType, RoomEventType, StateEventType, - }, - push::{Action, Ruleset, Tweak}, - serde::{CanonicalJsonObject, CanonicalJsonValue, Raw}, - state_res::{self, RoomVersion, StateMap}, - uint, DeviceId, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, -}; -use serde::Deserialize; -use serde_json::value::to_raw_value; -use std::{ - borrow::Cow, - collections::{hash_map, BTreeMap, HashMap, HashSet}, - fmt::Debug, - iter, - mem::size_of, - sync::{Arc, Mutex, RwLock}, -}; -use tokio::sync::MutexGuard; -use tracing::{error, warn}; - -use super::{abstraction::Tree, pusher}; - -/// The unique identifier of each state group. -/// -/// This is created when a state group is added to the database by -/// hashing the entire state. -pub type StateHashId = Vec<u8>; -pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()]; - -pub struct Rooms { - pub edus: RoomEdus, - pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = ShortRoomId + Count - pub(super) eventid_pduid: Arc<dyn Tree>, - pub(super) roomid_pduleaves: Arc<dyn Tree>, - pub(super) alias_roomid: Arc<dyn Tree>, - pub(super) aliasid_alias: Arc<dyn Tree>, // AliasId = RoomId + Count - pub(super) publicroomids: Arc<dyn Tree>, - - pub(super) tokenids: Arc<dyn Tree>, // TokenId = ShortRoomId + Token + PduIdCount - - /// Participating servers in a room. - pub(super) roomserverids: Arc<dyn Tree>, // RoomServerId = RoomId + ServerName - pub(super) serverroomids: Arc<dyn Tree>, // ServerRoomId = ServerName + RoomId - - pub(super) userroomid_joined: Arc<dyn Tree>, - pub(super) roomuserid_joined: Arc<dyn Tree>, - pub(super) roomid_joinedcount: Arc<dyn Tree>, - pub(super) roomid_invitedcount: Arc<dyn Tree>, - pub(super) roomuseroncejoinedids: Arc<dyn Tree>, - pub(super) userroomid_invitestate: Arc<dyn Tree>, // InviteState = Vec<Raw<Pdu>> - pub(super) roomuserid_invitecount: Arc<dyn Tree>, // InviteCount = Count - pub(super) userroomid_leftstate: Arc<dyn Tree>, - pub(super) roomuserid_leftcount: Arc<dyn Tree>, - - pub(super) disabledroomids: Arc<dyn Tree>, // Rooms where incoming federation handling is disabled - - pub(super) lazyloadedids: Arc<dyn Tree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId - - pub(super) userroomid_notificationcount: Arc<dyn Tree>, // NotifyCount = u64 - pub(super) userroomid_highlightcount: Arc<dyn Tree>, // HightlightCount = u64 - - /// Remember the current state hash of a room. - pub(super) roomid_shortstatehash: Arc<dyn Tree>, - pub(super) roomsynctoken_shortstatehash: Arc<dyn Tree>, - /// Remember the state hash at events in the past. - pub(super) shorteventid_shortstatehash: Arc<dyn Tree>, - /// StateKey = EventType + StateKey, ShortStateKey = Count - pub(super) statekey_shortstatekey: Arc<dyn Tree>, - pub(super) shortstatekey_statekey: Arc<dyn Tree>, - - pub(super) roomid_shortroomid: Arc<dyn Tree>, - - pub(super) shorteventid_eventid: Arc<dyn Tree>, - pub(super) eventid_shorteventid: Arc<dyn Tree>, - - pub(super) statehash_shortstatehash: Arc<dyn Tree>, - pub(super) shortstatehash_statediff: Arc<dyn Tree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) - - pub(super) shorteventid_authchain: Arc<dyn Tree>, - - /// RoomId + EventId -> outlier PDU. - /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. - pub(super) eventid_outlierpdu: Arc<dyn Tree>, - pub(super) softfailedeventids: Arc<dyn Tree>, - - /// RoomId + EventId -> Parent PDU EventId. - pub(super) referencedevents: Arc<dyn Tree>, - - pub(super) pdu_cache: Mutex<LruCache<Box<EventId>, Arc<PduEvent>>>, - pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>, - pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, - pub(super) eventidshort_cache: Mutex<LruCache<Box<EventId>, u64>>, - pub(super) statekeyshort_cache: Mutex<LruCache<(StateEventType, String), u64>>, - pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>, - pub(super) our_real_users_cache: RwLock<HashMap<Box<RoomId>, Arc<HashSet<Box<UserId>>>>>, - pub(super) appservice_in_room_cache: RwLock<HashMap<Box<RoomId>, HashMap<String, bool>>>, - pub(super) lazy_load_waiting: - Mutex<HashMap<(Box<UserId>, Box<DeviceId>, Box<RoomId>, u64), HashSet<Box<UserId>>>>, - pub(super) stateinfo_cache: Mutex< - LruCache< - u64, - Vec<( - u64, // sstatehash - HashSet<CompressedStateEvent>, // full state - HashSet<CompressedStateEvent>, // added - HashSet<CompressedStateEvent>, // removed - )>, - >, - >, - pub(super) lasttimelinecount_cache: Mutex<HashMap<Box<RoomId>, u64>>, -} - -impl Rooms { - /// Returns true if a given room version is supported - #[tracing::instrument(skip(self, db))] - pub fn is_supported_version(&self, db: &Database, room_version: &RoomVersionId) -> bool { - db.globals.supported_room_versions().contains(room_version) - } - - /// Builds a StateMap by iterating over all keys that start - /// with state_hash, this gives the full state for the given state_hash. - #[tracing::instrument(skip(self))] - pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - let mut result = BTreeMap::new(); - let mut i = 0; - for compressed in full_state.into_iter() { - let parsed = self.parse_compressed_state_event(compressed)?; - result.insert(parsed.0, parsed.1); - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - Ok(result) - } - - #[tracing::instrument(skip(self))] - pub async fn state_full( - &self, - shortstatehash: u64, - ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - - let mut result = HashMap::new(); - let mut i = 0; - for compressed in full_state { - let (_, eventid) = self.parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); - } - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - - Ok(result) - } - - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn state_get_id( - &self, - shortstatehash: u64, - event_type: &StateEventType, - state_key: &str, - ) -> Result<Option<Arc<EventId>>> { - let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { - Some(s) => s, - None => return Ok(None), - }; - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - Ok(full_state - .into_iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - self.parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) - } - - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn state_get( - &self, - shortstatehash: u64, - event_type: &StateEventType, - state_key: &str, - ) -> Result<Option<Arc<PduEvent>>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) - } - - /// Returns the state hash for this pdu. - pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { - self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database( - "Invalid shortstatehash bytes in shorteventid_shortstatehash", - ) - }) - }) - .transpose() - }) - } - - /// Returns the last state hash key added to the db for the given room. - #[tracing::instrument(skip(self))] - pub fn current_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { - self.roomid_shortstatehash - .get(room_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) - } - - /// This fetches auth events from the current state. - #[tracing::instrument(skip(self))] - pub fn get_auth_events( - &self, - room_id: &RoomId, - kind: &RoomEventType, - sender: &UserId, - state_key: Option<&str>, - content: &serde_json::value::RawValue, - ) -> Result<StateMap<Arc<PduEvent>>> { - let shortstatehash = - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - current_shortstatehash - } else { - return Ok(HashMap::new()); - }; - - let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content) - .expect("content is a valid JSON object"); - - let mut sauthevents = auth_events - .into_iter() - .filter_map(|(event_type, state_key)| { - self.get_shortstatekey(&event_type.to_string().into(), &state_key) - .ok() - .flatten() - .map(|s| (s, (event_type, state_key))) - }) - .collect::<HashMap<_, _>>(); - - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - - Ok(full_state - .into_iter() - .filter_map(|compressed| self.parse_compressed_state_event(compressed).ok()) - .filter_map(|(shortstatekey, event_id)| { - sauthevents.remove(&shortstatekey).map(|k| (k, event_id)) - }) - .filter_map(|(k, event_id)| self.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu))) - .collect()) - } - - /// Generate a new StateHash. - /// - /// A unique hash made from hashing all PDU ids of the state joined with 0xff. - fn calculate_hash(&self, bytes_list: &[&[u8]]) -> StateHashId { - // We only hash the pdu's event ids, not the whole pdu - let bytes = bytes_list.join(&0xff); - let hash = digest::digest(&digest::SHA256, &bytes); - hash.as_ref().into() - } - - /// Checks if a room exists. - #[tracing::instrument(skip(self))] - pub fn exists(&self, room_id: &RoomId) -> Result<bool> { - let prefix = match self.get_shortroomid(room_id)? { - Some(b) => b.to_be_bytes().to_vec(), - None => return Ok(false), - }; - - // Look for PDUs in that room. - Ok(self - .pduid_pdu - .iter_from(&prefix, false) - .next() - .filter(|(k, _)| k.starts_with(&prefix)) - .is_some()) - } - - /// Checks if a room exists. - #[tracing::instrument(skip(self))] - pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Look for PDUs in that room. - self.pduid_pdu - .iter_from(&prefix, false) - .filter(|(k, _)| k.starts_with(&prefix)) - .map(|(_, pdu)| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid first PDU in db.")) - .map(Arc::new) - }) - .next() - .transpose() - } - - /// Force the creation of a new StateHash and insert it into the db. - /// - /// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot. - #[tracing::instrument(skip(self, new_state_ids_compressed, db))] - pub fn force_state( - &self, - room_id: &RoomId, - new_state_ids_compressed: HashSet<CompressedStateEvent>, - db: &Database, - ) -> Result<()> { - let previous_shortstatehash = self.current_shortstatehash(room_id)?; - - let state_hash = self.calculate_hash( - &new_state_ids_compressed - .iter() - .map(|bytes| &bytes[..]) - .collect::<Vec<_>>(), - ); - - let (new_shortstatehash, already_existed) = - self.get_or_create_shortstatehash(&state_hash, &db.globals)?; - - if Some(new_shortstatehash) == previous_shortstatehash { - return Ok(()); - } - - let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() - { - let statediffnew: HashSet<_> = new_state_ids_compressed - .difference(&parent_stateinfo.1) - .copied() - .collect(); - - let statediffremoved: HashSet<_> = parent_stateinfo - .1 - .difference(&new_state_ids_compressed) - .copied() - .collect(); - - (statediffnew, statediffremoved) - } else { - (new_state_ids_compressed, HashSet::new()) - }; - - if !already_existed { - self.save_state_from_diff( - new_shortstatehash, - statediffnew.clone(), - statediffremoved, - 2, // every state change is 2 event changes on average - states_parents, - )?; - }; - - for event_id in statediffnew.into_iter().filter_map(|new| { - self.parse_compressed_state_event(new) - .ok() - .map(|(_, id)| id) - }) { - let pdu = match self.get_pdu_json(&event_id)? { - Some(pdu) => pdu, - None => continue, - }; - - if pdu.get("type").and_then(|val| val.as_str()) != Some("m.room.member") { - continue; - } - - let pdu: PduEvent = match serde_json::from_str( - &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), - ) { - Ok(pdu) => pdu, - Err(_) => continue, - }; - - #[derive(Deserialize)] - struct ExtractMembership { - membership: MembershipState, - } - - let membership = match serde_json::from_str::<ExtractMembership>(pdu.content.get()) { - Ok(e) => e.membership, - Err(_) => continue, - }; - - let state_key = match pdu.state_key { - Some(k) => k, - None => continue, - }; - - let user_id = match UserId::parse(state_key) { - Ok(id) => id, - Err(_) => continue, - }; - - self.update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?; - } - - self.update_joined_count(room_id, db)?; - - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - - Ok(()) - } - - /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. - #[tracing::instrument(skip(self))] - pub fn load_shortstatehash_info( - &self, - shortstatehash: u64, - ) -> Result< - Vec<( - u64, // sstatehash - HashSet<CompressedStateEvent>, // full state - HashSet<CompressedStateEvent>, // added - HashSet<CompressedStateEvent>, // removed - )>, - > { - if let Some(r) = self - .stateinfo_cache - .lock() - .unwrap() - .get_mut(&shortstatehash) - { - return Ok(r.clone()); - } - - let value = self - .shortstatehash_statediff - .get(&shortstatehash.to_be_bytes())? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; - let parent = - utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length"); - - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); - - let mut i = size_of::<u64>(); - while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i += size_of::<u64>(); - continue; - } - if add_mode { - added.insert(v.try_into().expect("we checked the size above")); - } else { - removed.insert(v.try_into().expect("we checked the size above")); - } - i += 2 * size_of::<u64>(); - } - - if parent != 0_u64 { - let mut response = self.load_shortstatehash_info(parent)?; - let mut state = response.last().unwrap().1.clone(); - state.extend(added.iter().copied()); - for r in &removed { - state.remove(r); - } - - response.push((shortstatehash, state, added, removed)); - - Ok(response) - } else { - let response = vec![(shortstatehash, added.clone(), added, removed)]; - self.stateinfo_cache - .lock() - .unwrap() - .insert(shortstatehash, response.clone()); - Ok(response) - } - } - - pub fn compress_state_event( - &self, - shortstatekey: u64, - event_id: &EventId, - globals: &super::globals::Globals, - ) -> Result<CompressedStateEvent> { - let mut v = shortstatekey.to_be_bytes().to_vec(); - v.extend_from_slice( - &self - .get_or_create_shorteventid(event_id, globals)? - .to_be_bytes(), - ); - Ok(v.try_into().expect("we checked the size above")) - } - - /// Returns shortstatekey, event id - pub fn parse_compressed_state_event( - &self, - compressed_event: CompressedStateEvent, - ) -> Result<(u64, Arc<EventId>)> { - Ok(( - utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]) - .expect("bytes have right length"), - self.get_eventid_from_short( - utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]) - .expect("bytes have right length"), - )?, - )) - } - - /// Creates a new shortstatehash that often is just a diff to an already existing - /// shortstatehash and therefore very efficient. - /// - /// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer - /// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0 - /// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's - /// based on layer n-2. If that layer is also too big, it will recursively fix above layers too. - /// - /// * `shortstatehash` - Shortstatehash of this state - /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid - /// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid - /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer - /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer - #[tracing::instrument(skip( - self, - statediffnew, - statediffremoved, - diff_to_sibling, - parent_states - ))] - pub fn save_state_from_diff( - &self, - shortstatehash: u64, - statediffnew: HashSet<CompressedStateEvent>, - statediffremoved: HashSet<CompressedStateEvent>, - diff_to_sibling: usize, - mut parent_states: Vec<( - u64, // sstatehash - HashSet<CompressedStateEvent>, // full state - HashSet<CompressedStateEvent>, // added - HashSet<CompressedStateEvent>, // removed - )>, - ) -> Result<()> { - let diffsum = statediffnew.len() + statediffremoved.len(); - - if parent_states.len() > 3 { - // Number of layers - // To many layers, we have to go deeper - let parent = parent_states.pop().unwrap(); - - let mut parent_new = parent.2; - let mut parent_removed = parent.3; - - for removed in statediffremoved { - if !parent_new.remove(&removed) { - // It was not added in the parent and we removed it - parent_removed.insert(removed); - } - // Else it was added in the parent and we removed it again. We can forget this change - } - - for new in statediffnew { - if !parent_removed.remove(&new) { - // It was not touched in the parent and we added it - parent_new.insert(new); - } - // Else it was removed in the parent and we added it again. We can forget this change - } - - self.save_state_from_diff( - shortstatehash, - parent_new, - parent_removed, - diffsum, - parent_states, - )?; - - return Ok(()); - } - - if parent_states.is_empty() { - // There is no parent layer, create a new state - let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent - for new in &statediffnew { - value.extend_from_slice(&new[..]); - } - - if !statediffremoved.is_empty() { - warn!("Tried to create new state with removals"); - } - - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value)?; - - return Ok(()); - }; - - // Else we have two options. - // 1. We add the current diff on top of the parent layer. - // 2. We replace a layer above - - let parent = parent_states.pop().unwrap(); - let parent_diff = parent.2.len() + parent.3.len(); - - if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { - // Diff too big, we replace above layer(s) - let mut parent_new = parent.2; - let mut parent_removed = parent.3; - - for removed in statediffremoved { - if !parent_new.remove(&removed) { - // It was not added in the parent and we removed it - parent_removed.insert(removed); - } - // Else it was added in the parent and we removed it again. We can forget this change - } - - for new in statediffnew { - if !parent_removed.remove(&new) { - // It was not touched in the parent and we added it - parent_new.insert(new); - } - // Else it was removed in the parent and we added it again. We can forget this change - } - - self.save_state_from_diff( - shortstatehash, - parent_new, - parent_removed, - diffsum, - parent_states, - )?; - } else { - // Diff small enough, we add diff as layer on top of parent - let mut value = parent.0.to_be_bytes().to_vec(); - for new in &statediffnew { - value.extend_from_slice(&new[..]); - } - - if !statediffremoved.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in &statediffremoved { - value.extend_from_slice(&removed[..]); - } - } - - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value)?; - } - - Ok(()) - } - - /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash( - &self, - state_hash: &StateHashId, - globals: &super::globals::Globals, - ) -> Result<(u64, bool)> { - Ok(match self.statehash_shortstatehash.get(state_hash)? { - Some(shortstatehash) => ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ), - None => { - let shortstatehash = globals.next_count()?; - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - } - }) - } - - pub fn get_or_create_shorteventid( - &self, - event_id: &EventId, - globals: &super::globals::Globals, - ) -> Result<u64> { - if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { - return Ok(*short); - } - - let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { - Some(shorteventid) => utils::u64_from_bytes(&shorteventid) - .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid - } - }; - - self.eventidshort_cache - .lock() - .unwrap() - .insert(event_id.to_owned(), short); - - Ok(short) - } - - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { - self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) - }) - .transpose() - } - - pub fn get_shortstatekey( - &self, - event_type: &StateEventType, - state_key: &str, - ) -> Result<Option<u64>> { - if let Some(short) = self - .statekeyshort_cache - .lock() - .unwrap() - .get_mut(&(event_type.clone(), state_key.to_owned())) - { - return Ok(Some(*short)); - } - - let mut statekey = event_type.to_string().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(state_key.as_bytes()); - - let short = self - .statekey_shortstatekey - .get(&statekey)? - .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey) - .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) - }) - .transpose()?; - - if let Some(s) = short { - self.statekeyshort_cache - .lock() - .unwrap() - .insert((event_type.clone(), state_key.to_owned()), s); - } - - Ok(short) - } - - pub fn get_or_create_shortroomid( - &self, - room_id: &RoomId, - globals: &super::globals::Globals, - ) -> Result<u64> { - Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { - Some(short) => utils::u64_from_bytes(&short) - .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, - None => { - let short = globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - } - }) - } - - pub fn get_or_create_shortstatekey( - &self, - event_type: &StateEventType, - state_key: &str, - globals: &super::globals::Globals, - ) -> Result<u64> { - if let Some(short) = self - .statekeyshort_cache - .lock() - .unwrap() - .get_mut(&(event_type.clone(), state_key.to_owned())) - { - return Ok(*short); - } - - let mut statekey = event_type.to_string().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(state_key.as_bytes()); - - let short = match self.statekey_shortstatekey.get(&statekey)? { - Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) - .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, - None => { - let shortstatekey = globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &statekey)?; - shortstatekey - } - }; - - self.statekeyshort_cache - .lock() - .unwrap() - .insert((event_type.clone(), state_key.to_owned()), short); - - Ok(short) - } - - pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { - if let Some(id) = self - .shorteventid_cache - .lock() - .unwrap() - .get_mut(&shorteventid) - { - return Ok(Arc::clone(id)); - } - - let bytes = self - .shorteventid_eventid - .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - - let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; - - self.shorteventid_cache - .lock() - .unwrap() - .insert(shorteventid, Arc::clone(&event_id)); - - Ok(event_id) - } - - pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - if let Some(id) = self - .shortstatekey_cache - .lock() - .unwrap() - .get_mut(&shortstatekey) - { - return Ok(id.clone()); - } - - let bytes = self - .shortstatekey_statekey - .get(&shortstatekey.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; - - let mut parts = bytes.splitn(2, |&b| b == 0xff); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; - - let event_type = - StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| { - Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?; - - let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| { - Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.") - })?; - - let result = (event_type, state_key); - - self.shortstatekey_cache - .lock() - .unwrap() - .insert(shortstatekey, result.clone()); - - Ok(result) - } - - /// Returns the full room state. - #[tracing::instrument(skip(self))] - pub async fn room_state_full( - &self, - room_id: &RoomId, - ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } - } - - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn room_state_get_id( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result<Option<Arc<EventId>>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } - } - - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn room_state_get( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result<Option<Arc<PduEvent>>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } - } - - /// Returns the `count` of this pdu's id. - pub fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> { - utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) - } - - /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pdu_id| self.pdu_count(&pdu_id)) - .transpose() - } - - #[tracing::instrument(skip(self))] - pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.pduid_pdu - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|b| self.pdu_count(&b.0)) - .transpose() - .map(|op| op.unwrap_or_default()) - } - - /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map_or_else( - || self.eventid_outlierpdu.get(event_id.as_bytes()), - |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { - Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) - }, - )? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } - - /// Returns the json of a pdu. - pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } - - /// Returns the json of a pdu. - pub fn get_non_outlier_pdu_json( - &self, - event_id: &EventId, - ) -> Result<Option<CanonicalJsonObject>> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } - - /// Returns the pdu's id. - pub fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { - self.eventid_pduid.get(event_id.as_bytes()) - } - - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } - - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { - if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { - return Ok(Some(Arc::clone(p))); - } - - if let Some(pdu) = self - .eventid_pduid - .get(event_id.as_bytes())? - .map_or_else( - || self.eventid_outlierpdu.get(event_id.as_bytes()), - |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { - Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) - }, - )? - .map(|pdu| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db.")) - .map(Arc::new) - }) - .transpose()? - { - self.pdu_cache - .lock() - .unwrap() - .insert(event_id.to_owned(), Arc::clone(&pdu)); - Ok(Some(pdu)) - } else { - Ok(None) - } - } - - /// Returns the pdu. - /// - /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) - } - - /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. - pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) - } - - /// Removes a pdu and creates a new one with the same id. - #[tracing::instrument(skip(self))] - fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"), - )?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::NotFound, - "PDU does not exist.", - )) - } - } - - /// Returns the leaf pdus of a room. - #[tracing::instrument(skip(self))] - pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } - - #[tracing::instrument(skip(self, room_id, event_ids))] - pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { - for prev in event_ids { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[])?; - } - - Ok(()) - } - - /// Replace the leaves of a room. - /// - /// The provided `event_ids` become the new leaves, this allows a room to have multiple - /// `prev_events`. - #[tracing::instrument(skip(self))] - pub fn replace_pdu_leaves<'a>( - &self, - room_id: &RoomId, - event_ids: impl IntoIterator<Item = &'a EventId> + Debug, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } - - for event_id in event_ids { - let mut key = prefix.to_owned(); - key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; - } - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(event_id.as_bytes()); - Ok(self.referencedevents.get(&key)?.is_some()) - } - - /// Returns the pdu from the outlier tree. - pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } - - /// Append the PDU as an outlier. - /// - /// Any event given to this will be processed (state-res) on another thread. - #[tracing::instrument(skip(self, pdu))] - pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.eventid_outlierpdu.insert( - event_id.as_bytes(), - &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), - ) - } - - #[tracing::instrument(skip(self))] - pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.softfailedeventids.insert(event_id.as_bytes(), &[]) - } - - #[tracing::instrument(skip(self))] - pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) - } - - /// Creates a new persisted data unit and adds it to a room. - /// - /// By this point the incoming event should be fully authenticated, no auth happens - /// in `append_pdu`. - /// - /// Returns pdu id - #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] - pub fn append_pdu<'a>( - &self, - pdu: &PduEvent, - mut pdu_json: CanonicalJsonObject, - leaves: impl IntoIterator<Item = &'a EventId> + Debug, - db: &Database, - ) -> Result<Vec<u8>> { - let shortroomid = self.get_shortroomid(&pdu.room_id)?.expect("room exists"); - - // Make unsigned fields correct. This is not properly documented in the spec, but state - // events need to have previous content in the unsigned field, so clients can easily - // interpret things like membership changes - if let Some(state_key) = &pdu.state_key { - if let CanonicalJsonValue::Object(unsigned) = pdu_json - .entry("unsigned".to_owned()) - .or_insert_with(|| CanonicalJsonValue::Object(Default::default())) - { - if let Some(shortstatehash) = self.pdu_shortstatehash(&pdu.event_id).unwrap() { - if let Some(prev_state) = self - .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) - .unwrap() - { - unsigned.insert( - "prev_content".to_owned(), - CanonicalJsonValue::Object( - utils::to_canonical_object(prev_state.content.clone()) - .expect("event is valid, we just created it"), - ), - ); - } - } - } else { - error!("Invalid unsigned type in pdu."); - } - } - - // We must keep track of all events that have been referenced. - self.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - self.replace_pdu_leaves(&pdu.room_id, leaves)?; - - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(pdu.room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - - let count1 = db.globals.next_count()?; - // Mark as read first so the sending client doesn't get a notification even if appending - // fails - self.edus - .private_read_set(&pdu.room_id, &pdu.sender, count1, &db.globals)?; - self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; - - let count2 = db.globals.next_count()?; - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&count2.to_be_bytes()); - - // There's a brief moment of time here where the count is updated but the pdu does not - // exist. This could theoretically lead to dropped pdus, but it's extremely rare - // - // Update: We fixed this using insert_lock - - self.pduid_pdu.insert( - &pdu_id, - &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - self.lasttimelinecount_cache - .lock() - .unwrap() - .insert(pdu.room_id.clone(), count2); - - self.eventid_pduid - .insert(pdu.event_id.as_bytes(), &pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; - - drop(insert_lock); - - // See if the event matches any known pushers - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - let sync_pdu = pdu.to_sync_room_event(); - - let mut notifies = Vec::new(); - let mut highlights = Vec::new(); - - for user in self.get_our_real_users(&pdu.room_id, db)?.iter() { - // Don't notify the user of their own events - if user == &pdu.sender { - continue; - } - - let rules_for_user = db - .account_data - .get( - None, - user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .map(|ev: PushRulesEvent| ev.content.global) - .unwrap_or_else(|| Ruleset::server_default(user)); - - let mut highlight = false; - let mut notify = false; - - for action in pusher::get_actions( - user, - &rules_for_user, - &power_levels, - &sync_pdu, - &pdu.room_id, - db, - )? { - match action { - Action::DontNotify => notify = false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => notify = true, - Action::SetTweak(Tweak::Highlight(true)) => { - highlight = true; - } - _ => {} - }; - } - - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(pdu.room_id.as_bytes()); - - if notify { - notifies.push(userroom_id.clone()); - } - - if highlight { - highlights.push(userroom_id); - } - - for senderkey in db.pusher.get_pusher_senderkeys(user) { - db.sending.send_push_pdu(&*pdu_id, senderkey)?; - } - } - - self.userroomid_notificationcount - .increment_batch(&mut notifies.into_iter())?; - self.userroomid_highlightcount - .increment_batch(&mut highlights.into_iter())?; - - match pdu.kind { - RoomEventType::RoomRedaction => { - if let Some(redact_id) = &pdu.redacts { - self.redact_pdu(redact_id, pdu)?; - } - } - RoomEventType::RoomMember => { - if let Some(state_key) = &pdu.state_key { - #[derive(Deserialize)] - struct ExtractMembership { - membership: MembershipState, - } - - // if the state_key fails - let target_user_id = UserId::parse(state_key.clone()) - .expect("This state_key was previously validated"); - - let content = serde_json::from_str::<ExtractMembership>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - let invite_state = match content.membership { - MembershipState::Invite => { - let state = self.calculate_invite_state(pdu)?; - Some(state) - } - _ => None, - }; - - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth - self.update_membership( - &pdu.room_id, - &target_user_id, - content.membership, - &pdu.sender, - invite_state, - db, - true, - )?; - } - } - RoomEventType::RoomMessage => { - #[derive(Deserialize)] - struct ExtractBody<'a> { - #[serde(borrow)] - body: Option<Cow<'a, str>>, - } - - let content = serde_json::from_str::<ExtractBody<'_>>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - if let Some(body) = content.body { - let mut batch = body - .split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= 50) - .map(str::to_lowercase) - .map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xff); - key.extend_from_slice(&pdu_id); - (key, Vec::new()) - }); - - self.tokenids.insert_batch(&mut batch)?; - - let admin_room = self.id_from_alias( - <&RoomAliasId>::try_from( - format!("#admins:{}", db.globals.server_name()).as_str(), - ) - .expect("#admins:server_name is a valid room alias"), - )?; - let server_user = format!("@conduit:{}", db.globals.server_name()); - - let to_conduit = body.starts_with(&format!("{}: ", server_user)); - - // This will evaluate to false if the emergency password is set up so that - // the administrator can execute commands as conduit - let from_conduit = - pdu.sender == server_user && db.globals.emergency_password().is_none(); - - if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { - db.admin.process_message(body.to_string()); - } - } - } - _ => {} - } - - Ok(pdu_id) - } - - #[tracing::instrument(skip(self))] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> { - match self - .lasttimelinecount_cache - .lock() - .unwrap() - .entry(room_id.to_owned()) - { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(&sender_user, &room_id, u64::MAX)? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .map(|(pduid, _)| self.pdu_count(&pduid)) - .next() - { - Ok(*v.insert(last_count?)) - } else { - Ok(0) - } - } - hash_map::Entry::Occupied(o) => Ok(*o.get()), - } - } - - #[tracing::instrument(skip(self))] - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) - .unwrap_or(Ok(0)) - } - - #[tracing::instrument(skip(self))] - pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_highlightcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) - .unwrap_or(Ok(0)) - } - - /// Generates a new StateHash and associates it with the incoming event. - /// - /// This adds all current state events (not including the incoming event) - /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, state_ids_compressed, globals))] - pub fn set_event_state( - &self, - event_id: &EventId, - room_id: &RoomId, - state_ids_compressed: HashSet<CompressedStateEvent>, - globals: &super::globals::Globals, - ) -> Result<()> { - let shorteventid = self.get_or_create_shorteventid(event_id, globals)?; - - let previous_shortstatehash = self.current_shortstatehash(room_id)?; - - let state_hash = self.calculate_hash( - &state_ids_compressed - .iter() - .map(|s| &s[..]) - .collect::<Vec<_>>(), - ); - - let (shortstatehash, already_existed) = - self.get_or_create_shortstatehash(&state_hash, globals)?; - - if !already_existed { - let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: HashSet<_> = state_ids_compressed - .difference(&parent_stateinfo.1) - .copied() - .collect(); - - let statediffremoved: HashSet<_> = parent_stateinfo - .1 - .difference(&state_ids_compressed) - .copied() - .collect(); - - (statediffnew, statediffremoved) - } else { - (state_ids_compressed, HashSet::new()) - }; - self.save_state_from_diff( - shortstatehash, - statediffnew, - statediffremoved, - 1_000_000, // high number because no state will be based on this one - states_parents, - )?; - } - - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; - - Ok(()) - } - - /// Generates a new StateHash and associates it with the incoming event. - /// - /// This adds all current state events (not including the incoming event) - /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, new_pdu, globals))] - pub fn append_to_state( - &self, - new_pdu: &PduEvent, - globals: &super::globals::Globals, - ) -> Result<u64> { - let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; - - let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; - - if let Some(p) = previous_shortstatehash { - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?; - } - - if let Some(state_key) = &new_pdu.state_key { - let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - - let shortstatekey = self.get_or_create_shortstatekey( - &new_pdu.kind.to_string().into(), - state_key, - globals, - )?; - - let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; - - let replaces = states_parents - .last() - .map(|info| { - info.1 - .iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - }) - .unwrap_or_default(); - - if Some(&new) == replaces { - return Ok(previous_shortstatehash.expect("must exist")); - } - - // TODO: statehash with deterministic inputs - let shortstatehash = globals.next_count()?; - - let mut statediffnew = HashSet::new(); - statediffnew.insert(new); - - let mut statediffremoved = HashSet::new(); - if let Some(replaces) = replaces { - statediffremoved.insert(*replaces); - } - - self.save_state_from_diff( - shortstatehash, - statediffnew, - statediffremoved, - 2, - states_parents, - )?; - - Ok(shortstatehash) - } else { - Ok(previous_shortstatehash.expect("first event in room must be a state event")) - } - } - - #[tracing::instrument(skip(self, invite_event))] - pub fn calculate_invite_state( - &self, - invite_event: &PduEvent, - ) -> Result<Vec<Raw<AnyStrippedStateEvent>>> { - let mut state = Vec::new(); - // Add recommended events - if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = self.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = - self.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? - { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = self.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? { - state.push(e.to_stripped_state_event()); - } - - state.push(invite_event.to_stripped_state_event()); - Ok(state) - } - - #[tracing::instrument(skip(self))] - pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64) -> Result<()> { - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; - - Ok(()) - } - - pub fn associate_token_shortstatehash( - &self, - room_id: &RoomId, - token: u64, - shortstatehash: u64, - ) -> Result<()> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) - } - - pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") - }) - }) - .transpose() - } - - /// Creates a new persisted data unit and adds it to a room. - #[tracing::instrument(skip(self, db, _mutex_lock))] - pub fn build_and_append_pdu( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - db: &Database, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room mutex - ) -> Result<Arc<EventId>> { - let PduBuilder { - event_type, - content, - unsigned, - state_key, - redacts, - } = pdu_builder; - - let prev_events = self - .get_pdu_leaves(room_id)? - .into_iter() - .take(20) - .collect::<Vec<_>>(); - - let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option<RoomCreateEventContent> = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - - // If there was no create event yet, assume we are creating a room with the default - // version right now - let room_version_id = create_event_content - .map_or(db.globals.default_room_version(), |create_event| { - create_event.room_version - }); - let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - - let auth_events = - self.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) - + uint!(1); - - let mut unsigned = unsigned.unwrap_or_default(); - if let Some(state_key) = &state_key { - if let Some(prev_pdu) = - self.room_state_get(room_id, &event_type.to_string().into(), state_key)? - { - unsigned.insert( - "prev_content".to_owned(), - serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), - ); - unsigned.insert( - "prev_sender".to_owned(), - serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), - ); - } - } - - let mut pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater").into(), - room_id: room_id.to_owned(), - sender: sender.to_owned(), - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"), - kind: event_type, - content, - state_key, - prev_events, - depth, - auth_events: auth_events - .iter() - .map(|(_, pdu)| pdu.event_id.clone()) - .collect(), - redacts, - unsigned: if unsigned.is_empty() { - None - } else { - Some(to_raw_value(&unsigned).expect("to_raw_value always works")) - }, - hashes: EventHash { - sha256: "aaa".to_owned(), - }, - signatures: None, - }; - - let auth_check = state_res::auth_check( - &room_version, - &pdu, - None::<PduEvent>, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - .map_err(|e| { - error!("{:?}", e); - Error::bad_database("Auth check failed.") - })?; - - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Event is not authorized.", - )); - } - - // Hash and sign - let mut pdu_json = - utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); - - pdu_json.remove("event_id"); - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - CanonicalJsonValue::String(db.globals.server_name().as_ref().to_owned()), - ); - - match ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut pdu_json, - &room_version_id, - ) { - Ok(_) => {} - Err(e) => { - return match e { - ruma::signatures::Error::PduSize => Err(Error::BadRequest( - ErrorKind::TooLarge, - "Message is too long", - )), - _ => Err(Error::BadRequest( - ErrorKind::Unknown, - "Signing event failed", - )), - } - } - } - - // Generate event id - pdu.event_id = EventId::parse_arc(format!( - "${}", - ruma::signatures::reference_hash(&pdu_json, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - pdu_json.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(pdu.event_id.as_str().to_owned()), - ); - - // Generate short event id - let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; - - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = self.append_to_state(&pdu, &db.globals)?; - - let pdu_id = self.append_pdu( - &pdu, - pdu_json, - // Since this PDU references all pdu_leaves we can update the leaves - // of the room - iter::once(&*pdu.event_id), - db, - )?; - - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - self.set_room_state(room_id, statehashid)?; - - let mut servers: HashSet<Box<ServerName>> = - self.room_servers(room_id).filter_map(|r| r.ok()).collect(); - - // In case we are kicking or banning a user, we need to inform their server of the change - if pdu.kind == RoomEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - servers.insert(Box::from(state_key_uid.server_name())); - } - } - - // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above - servers.remove(db.globals.server_name()); - - db.sending.send_pdu(servers.into_iter(), &pdu_id)?; - - for appservice in db.appservice.all()? { - if self.appservice_in_room(room_id, &appservice, db)? { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - continue; - } - - // If the RoomMember event has a non-empty state_key, it is targeted at someone. - // If it is our appservice user, we send this PDU to it. - if pdu.kind == RoomEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - if let Some(appservice_uid) = appservice - .1 - .get("sender_localpart") - .and_then(|string| string.as_str()) - .and_then(|string| { - UserId::parse_with_server_name(string, db.globals.server_name()).ok() - }) - { - if state_key_uid == &appservice_uid { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - continue; - } - } - } - } - - if let Some(namespaces) = appservice.1.get("namespaces") { - let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::<Vec<_>>() - }); - let aliases = namespaces - .get("aliases") - .and_then(|aliases| aliases.as_sequence()) - .map_or_else(Vec::new, |aliases| { - aliases - .iter() - .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) - .collect::<Vec<_>>() - }); - let rooms = namespaces - .get("rooms") - .and_then(|rooms| rooms.as_sequence()); - - let matching_users = |users: &Regex| { - users.is_match(pdu.sender.as_str()) - || pdu.kind == RoomEventType::RoomMember - && pdu - .state_key - .as_ref() - .map_or(false, |state_key| users.is_match(state_key)) - }; - let matching_aliases = |aliases: &Regex| { - self.room_aliases(room_id) - .filter_map(|r| r.ok()) - .any(|room_alias| aliases.is_match(room_alias.as_str())) - }; - - if aliases.iter().any(matching_aliases) - || rooms.map_or(false, |rooms| rooms.contains(&room_id.as_str().into())) - || users.iter().any(matching_users) - { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - } - } - } - - Ok(pdu.event_id) - } - - /// Returns an iterator over all PDUs in a room. - #[tracing::instrument(skip(self))] - pub fn all_pdus<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { - self.pdus_since(user_id, room_id, 0) - } - - /// Returns an iterator over all events in a room that happened after the event with id `since` - /// in chronological order. - #[tracing::instrument(skip(self))] - pub fn pdus_since<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - since: u64, - ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Skip the first pdu if it's exactly at since, because we sent that last time - let mut first_pdu_id = prefix.clone(); - first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(&first_pdu_id, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::<PduEvent>(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) - } - - /// Returns an iterator over all events and their tokens in a room that happened before the - /// event with id `until` in reverse-chronological order. - #[tracing::instrument(skip(self))] - pub fn pdus_until<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - until: u64, - ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { - // Create the first part of the full pdu id - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` - - let current: &[u8] = ¤t; - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(current, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::<PduEvent>(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) - } - - /// Returns an iterator over all events and their token in a room that happened after the event - /// with id `from` in chronological order. - #[tracing::instrument(skip(self))] - pub fn pdus_after<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - from: u64, - ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { - // Create the first part of the full pdu id - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event - - let current: &[u8] = ¤t; - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(current, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::<PduEvent>(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) - } - - /// Replace a PDU with the redacted form. - #[tracing::instrument(skip(self, reason))] - pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { - if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? - .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; - pdu.redact(reason)?; - self.replace_pdu(&pdu_id, &pdu)?; - } - // If event does not exist, just noop - Ok(()) - } - - /// Update current membership data. - #[tracing::instrument(skip(self, last_state, db))] - pub fn update_membership( - &self, - room_id: &RoomId, - user_id: &UserId, - membership: MembershipState, - sender: &UserId, - last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, - db: &Database, - update_joined_count: bool, - ) -> Result<()> { - // Keep track what remote users exist by adding them as "deactivated" users - if user_id.server_name() != db.globals.server_name() { - db.users.create(user_id, None)?; - // TODO: displayname, avatar url - } - - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(user_id.server_name().as_bytes()); - - let mut serverroom_id = user_id.server_name().as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - match &membership { - MembershipState::Join => { - // Check if the user never joined this room - if !self.once_joined(user_id, room_id)? { - // Add the user ID to the join list then - self.roomuseroncejoinedids.insert(&userroom_id, &[])?; - - // Check if the room has a predecessor - if let Some(predecessor) = self - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .and_then(|create| serde_json::from_str(create.content.get()).ok()) - .and_then(|content: RoomCreateEventContent| content.predecessor) - { - // Copy user settings from predecessor to the current room: - // - Push rules - // - // TODO: finish this once push rules are implemented. - // - // let mut push_rules_event_content: PushRulesEvent = account_data - // .get( - // None, - // user_id, - // EventType::PushRules, - // )?; - // - // NOTE: find where `predecessor.room_id` match - // and update to `room_id`. - // - // account_data - // .update( - // None, - // user_id, - // EventType::PushRules, - // &push_rules_event_content, - // globals, - // ) - // .ok(); - - // Copy old tags to new room - if let Some(tag_event) = db.account_data.get::<TagEvent>( - Some(&predecessor.room_id), - user_id, - RoomAccountDataEventType::Tag, - )? { - db.account_data - .update( - Some(room_id), - user_id, - RoomAccountDataEventType::Tag, - &tag_event, - &db.globals, - ) - .ok(); - }; - - // Copy direct chat flag - if let Some(mut direct_event) = db.account_data.get::<DirectEvent>( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - )? { - let mut room_ids_updated = false; - - for room_ids in direct_event.content.0.values_mut() { - if room_ids.iter().any(|r| r == &predecessor.room_id) { - room_ids.push(room_id.to_owned()); - room_ids_updated = true; - } - } - - if room_ids_updated { - db.account_data.update( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - &direct_event, - &db.globals, - )?; - } - }; - } - } - - if update_joined_count { - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - self.userroomid_joined.insert(&userroom_id, &[])?; - self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - } - MembershipState::Invite => { - // We want to know if the sender is ignored by the receiver - let is_ignored = db - .account_data - .get::<IgnoredUserListEvent>( - None, // Ignored users are in global account data - user_id, // Receiver - GlobalAccountDataEventType::IgnoredUserList - .to_string() - .into(), - )? - .map_or(false, |ignored| { - ignored - .content - .ignored_users - .iter() - .any(|user| user == sender) - }); - - if is_ignored { - return Ok(()); - } - - if update_joined_count { - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - self.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()) - .expect("state to bytes always works"), - )?; - self.roomuserid_invitecount - .insert(&roomuser_id, &db.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - } - MembershipState::Leave | MembershipState::Ban => { - if update_joined_count - && self - .room_members(room_id) - .chain(self.room_members_invited(room_id)) - .filter_map(|r| r.ok()) - .all(|u| u.server_name() != user_id.server_name()) - { - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - self.userroomid_leftstate.insert( - &userroom_id, - &serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), - )?; // TODO - self.roomuserid_leftcount - .insert(&roomuser_id, &db.globals.next_count()?.to_be_bytes())?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - } - _ => {} - } - - if update_joined_count { - self.update_joined_count(room_id, db)?; - } - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id, db))] - pub fn update_joined_count(&self, room_id: &RoomId, db: &Database) -> Result<()> { - let mut joinedcount = 0_u64; - let mut invitedcount = 0_u64; - let mut joined_servers = HashSet::new(); - let mut real_users = HashSet::new(); - - for joined in self.room_members(room_id).filter_map(|r| r.ok()) { - joined_servers.insert(joined.server_name().to_owned()); - if joined.server_name() == db.globals.server_name() - && !db.users.is_deactivated(&joined).unwrap_or(true) - { - real_users.insert(joined); - } - joinedcount += 1; - } - - for invited in self.room_members_invited(room_id).filter_map(|r| r.ok()) { - joined_servers.insert(invited.server_name().to_owned()); - invitedcount += 1; - } - - self.roomid_joinedcount - .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; - - self.roomid_invitedcount - .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; - - self.our_real_users_cache - .write() - .unwrap() - .insert(room_id.to_owned(), Arc::new(real_users)); - - for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - - self.appservice_in_room_cache - .write() - .unwrap() - .remove(room_id); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id, db))] - pub fn get_our_real_users( - &self, - room_id: &RoomId, - db: &Database, - ) -> Result<Arc<HashSet<Box<UserId>>>> { - let maybe = self - .our_real_users_cache - .read() - .unwrap() - .get(room_id) - .cloned(); - if let Some(users) = maybe { - Ok(users) - } else { - self.update_joined_count(room_id, db)?; - Ok(Arc::clone( - self.our_real_users_cache - .read() - .unwrap() - .get(room_id) - .unwrap(), - )) - } - } - - #[tracing::instrument(skip(self, room_id, appservice, db))] - pub fn appservice_in_room( - &self, - room_id: &RoomId, - appservice: &(String, serde_yaml::Value), - db: &Database, - ) -> Result<bool> { - let maybe = self - .appservice_in_room_cache - .read() - .unwrap() - .get(room_id) - .and_then(|map| map.get(&appservice.0)) - .copied(); - - if let Some(b) = maybe { - Ok(b) - } else if let Some(namespaces) = appservice.1.get("namespaces") { - let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::<Vec<_>>() - }); - - let bridge_user_id = appservice - .1 - .get("sender_localpart") - .and_then(|string| string.as_str()) - .and_then(|string| { - UserId::parse_with_server_name(string, db.globals.server_name()).ok() - }); - - let in_room = bridge_user_id - .map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) - || self.room_members(room_id).any(|userid| { - userid.map_or(false, |userid| { - users.iter().any(|r| r.is_match(userid.as_str())) - }) - }); - - self.appservice_in_room_cache - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default() - .insert(appservice.0.clone(), in_room); - - Ok(in_room) - } else { - Ok(false) - } - } - - // Make a user leave all their joined rooms - #[tracing::instrument(skip(self, db))] - pub async fn leave_all_rooms(&self, user_id: &UserId, db: &Database) -> Result<()> { - let all_rooms = db - .rooms - .rooms_joined(user_id) - .chain(db.rooms.rooms_invited(user_id).map(|t| t.map(|(r, _)| r))) - .collect::<Vec<_>>(); - - for room_id in all_rooms { - let room_id = match room_id { - Ok(room_id) => room_id, - Err(_) => continue, - }; - - let _ = self.leave_room(user_id, &room_id, db).await; - } - - Ok(()) - } - - #[tracing::instrument(skip(self, db))] - pub async fn leave_room( - &self, - user_id: &UserId, - room_id: &RoomId, - db: &Database, - ) -> Result<()> { - // Ask a remote server if we don't have this room - if !self.exists(room_id)? && room_id.server_name() != db.globals.server_name() { - if let Err(e) = self.remote_leave_room(user_id, room_id, db).await { - warn!("Failed to leave room {} remotely: {}", user_id, e); - // Don't tell the client about this error - } - - let last_state = self - .invite_state(user_id, room_id)? - .map_or_else(|| self.left_state(user_id, room_id), |s| Ok(Some(s)))?; - - // We always drop the invite, we can't rely on other servers - self.update_membership( - room_id, - user_id, - MembershipState::Leave, - user_id, - last_state, - db, - true, - )?; - } else { - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - let mut event: RoomMemberEventContent = serde_json::from_str( - self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot leave a room you are not a member of.", - ))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - event.membership = MembershipState::Leave; - - self.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - user_id, - room_id, - db, - &state_lock, - )?; - } - - Ok(()) - } - - #[tracing::instrument(skip(self, db))] - async fn remote_leave_room( - &self, - user_id: &UserId, - room_id: &RoomId, - db: &Database, - ) -> Result<()> { - let mut make_leave_response_and_server = Err(Error::BadServerResponse( - "No server available to assist in leaving.", - )); - - let invite_state = db - .rooms - .invite_state(user_id, room_id)? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "User is not invited.", - ))?; - - let servers: HashSet<_> = invite_state - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); - - for remote_server in servers { - let make_leave_response = db - .sending - .send_federation_request( - &db.globals, - &remote_server, - federation::membership::prepare_leave_event::v1::Request { room_id, user_id }, - ) - .await; - - make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); - - if make_leave_response_and_server.is_ok() { - break; - } - } - - let (make_leave_response, remote_server) = make_leave_response_and_server?; - - let room_version_id = match make_leave_response.room_version { - Some(version) if self.is_supported_version(&db, &version) => version, - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - - let mut leave_event_stub = - serde_json::from_str::<CanonicalJsonObject>(make_leave_response.event.get()).map_err( - |_| Error::BadServerResponse("Invalid make_leave event json received from server."), - )?; - - // TODO: Is origin needed? - leave_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(db.globals.server_name().as_str().to_owned()), - ); - leave_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - leave_event_stub.remove("event_id"); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut leave_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); - - // Generate event id - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - leave_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - // It has enough fields to be called a proper event now - let leave_event = leave_event_stub; - - db.sending - .send_federation_request( - &db.globals, - &remote_server, - federation::membership::create_leave_event::v2::Request { - room_id, - event_id: &event_id, - pdu: &PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), - }, - ) - .await?; - - Ok(()) - } - - /// Makes a user forget a room. - #[tracing::instrument(skip(self))] - pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - #[tracing::instrument(skip(self, globals))] - pub fn set_alias( - &self, - alias: &RoomAliasId, - room_id: Option<&RoomId>, - globals: &super::globals::Globals, - ) -> Result<()> { - if let Some(room_id) = room_id { - // New alias - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xff); - aliasid.extend_from_slice(&globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?; - } else { - // room_id=None means remove alias - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { - let mut prefix = room_id.to_vec(); - prefix.push(0xff); - - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - self.alias_roomid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Alias does not exist.", - )); - } - } - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn id_from_alias(&self, alias: &RoomAliasId) -> Result<Option<Box<RoomId>>> { - self.alias_roomid - .get(alias.alias().as_bytes())? - .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in alias_roomid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) - }) - .transpose() - } - - #[tracing::instrument(skip(self))] - pub fn room_aliases<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator<Item = Result<Box<RoomAliasId>>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - }) - } - - #[tracing::instrument(skip(self))] - pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> { - if public { - self.publicroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.publicroomids.remove(room_id.as_bytes())?; - } - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { - Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) - } - - #[tracing::instrument(skip(self))] - pub fn iter_ids(&self) -> impl Iterator<Item = Result<Box<RoomId>>> + '_ { - self.roomid_shortroomid.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in publicroomids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) - }) - } - - #[tracing::instrument(skip(self))] - pub fn public_rooms(&self) -> impl Iterator<Item = Result<Box<RoomId>>> + '_ { - self.publicroomids.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in publicroomids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) - }) - } - - #[tracing::instrument(skip(self))] - pub fn search_pdus<'a>( - &'a self, - room_id: &RoomId, - search_string: &str, - ) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - let prefix_clone = prefix.clone(); - - let words: Vec<_> = search_string - .split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .map(str::to_lowercase) - .collect(); - - let iterators = words.clone().into_iter().map(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xff); - - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.tokenids - .iter_from(&last_possible_id, true) // Newest pdus first - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(|(key, _)| key[key.len() - size_of::<u64>()..].to_vec()) - }); - - Ok(utils::common_elements(iterators, |a, b| { - // We compare b with a because we reversed the iterator earlier - b.cmp(a) - }) - .map(|iter| { - ( - iter.map(move |id| { - let mut pduid = prefix_clone.clone(); - pduid.extend_from_slice(&id); - pduid - }), - words, - ) - })) - } - - #[tracing::instrument(skip(self))] - pub fn get_shared_rooms<'a>( - &'a self, - users: Vec<Box<UserId>>, - ) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xff) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - + 1; // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(|r| r.ok()) - }); - - // We use the default compare function because keys are sorted correctly (not reversed) - Ok(utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { - Error::bad_database("Invalid RoomId bytes in userroomid_joined") - })?) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - })) - } - - /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self))] - pub fn room_servers<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator<Item = Result<Box<ServerName>>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.roomserverids.scan_prefix(prefix).map(|(key, _)| { - ServerName::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Server name in roomserverids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) - }) - } - - #[tracing::instrument(skip(self))] - pub fn server_in_room<'a>(&'a self, server: &ServerName, room_id: &RoomId) -> Result<bool> { - let mut key = server.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - - self.serverroomids.get(&key).map(|o| o.is_some()) - } - - /// Returns an iterator of all rooms a server participates in (as far as we know). - #[tracing::instrument(skip(self))] - pub fn server_rooms<'a>( - &'a self, - server: &ServerName, - ) -> impl Iterator<Item = Result<Box<RoomId>>> + 'a { - let mut prefix = server.as_bytes().to_vec(); - prefix.push(0xff); - - self.serverroomids.scan_prefix(prefix).map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) - }) - } - - /// Returns an iterator over all joined members of a room. - #[tracing::instrument(skip(self))] - pub fn room_members<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator<Item = Result<Box<UserId>>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("User ID in roomuserid_joined is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) - }) - } - - #[tracing::instrument(skip(self))] - pub fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { - self.roomid_joinedcount - .get(room_id.as_bytes())? - .map(|b| { - utils::u64_from_bytes(&b) - .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) - }) - .transpose() - } - - #[tracing::instrument(skip(self))] - pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { - self.roomid_invitedcount - .get(room_id.as_bytes())? - .map(|b| { - utils::u64_from_bytes(&b) - .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) - }) - .transpose() - } - - /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self))] - pub fn room_useroncejoined<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator<Item = Result<Box<UserId>>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.roomuseroncejoinedids - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("User ID in room_useroncejoined is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - }) - } - - /// Returns an iterator over all invited members of a room. - #[tracing::instrument(skip(self))] - pub fn room_members_invited<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator<Item = Result<Box<UserId>>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.roomuserid_invitecount - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("User ID in roomuserid_invited is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - }) - } - - #[tracing::instrument(skip(self))] - pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_invitecount - .get(&key)? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid invitecount in db.") - })?)) - }) - } - - #[tracing::instrument(skip(self))] - pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_leftcount - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid leftcount in db.")) - }) - .transpose() - } - - pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { - Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) - } - - /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self))] - pub fn rooms_joined<'a>( - &'a self, - user_id: &UserId, - ) -> impl Iterator<Item = Result<Box<RoomId>>> + 'a { - self.userroomid_joined - .scan_prefix(user_id.as_bytes().to_vec()) - .map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_joined is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) - }) - } - - /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self))] - pub fn rooms_invited<'a>( - &'a self, - user_id: &UserId, - ) -> impl Iterator<Item = Result<(Box<RoomId>, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok((room_id, state)) - }) - } - - #[tracing::instrument(skip(self))] - pub fn invite_state( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok(state) - }) - .transpose() - } - - #[tracing::instrument(skip(self))] - pub fn left_state( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok(state) - }) - .transpose() - } - - /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self))] - pub fn rooms_left<'a>( - &'a self, - user_id: &UserId, - ) -> impl Iterator<Item = Result<(Box<RoomId>, Vec<Raw<AnySyncStateEvent>>)>> + 'a { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok((room_id, state)) - }) - } - - #[tracing::instrument(skip(self))] - pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - pub fn get_auth_chain_from_cache<'a>( - &'a self, - key: &[u64], - ) -> Result<Option<Arc<HashSet<u64>>>> { - // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { - return Ok(Some(Arc::clone(result))); - } - - // Check DB cache - if key.len() == 1 { - if let Some(chain) = - self.shorteventid_authchain - .get(&key[0].to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::<u64>()) - .map(|chunk| { - utils::u64_from_bytes(chunk).expect("byte length is correct") - }) - .collect() - }) - { - let chain = Arc::new(chain); - - // Cache in RAM - self.auth_chain_cache - .lock() - .unwrap() - .insert(vec![key[0]], Arc::clone(&chain)); - - return Ok(Some(chain)); - } - } - - Ok(None) - } - - #[tracing::instrument(skip(self))] - pub fn cache_auth_chain(&self, key: Vec<u64>, chain: Arc<HashSet<u64>>) -> Result<()> { - // Persist in db - if key.len() == 1 { - self.shorteventid_authchain.insert( - &key[0].to_be_bytes(), - &chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::<Vec<u8>>(), - )?; - } - - // Cache in RAM - self.auth_chain_cache.lock().unwrap().insert(key, chain); - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn lazy_load_was_sent_before( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - ll_user: &UserId, - ) -> Result<bool> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(ll_user.as_bytes()); - Ok(self.lazyloadedids.get(&key)?.is_some()) - } - - #[tracing::instrument(skip(self))] - pub fn lazy_load_mark_sent( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - lazy_load: HashSet<Box<UserId>>, - count: u64, - ) { - self.lazy_load_waiting.lock().unwrap().insert( - ( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - count, - ), - lazy_load, - ); - } - - #[tracing::instrument(skip(self))] - pub fn lazy_load_confirm_delivery( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - since: u64, - ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - since, - )) { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); - - for ll_id in user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.lazyloadedids.insert(&key, &[])?; - } - } - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn lazy_load_reset( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); - - for (key, _) in self.lazyloadedids.scan_prefix(prefix) { - self.lazyloadedids.remove(&key)?; - } - - Ok(()) - } - - /// Returns the room's version. - #[tracing::instrument(skip(self))] - pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { - let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option<RoomCreateEventContent> = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - let room_version = create_event_content - .map(|create_event| create_event.room_version) - .ok_or_else(|| Error::BadDatabase("Invalid room version"))?; - Ok(room_version) - } -} diff --git a/src/database/rooms/edus.rs b/src/database/rooms/edus.rs deleted file mode 100644 index 118efd4..0000000 --- a/src/database/rooms/edus.rs +++ /dev/null @@ -1,550 +0,0 @@ -use crate::{database::abstraction::Tree, utils, Error, Result}; -use ruma::{ - events::{ - presence::{PresenceEvent, PresenceEventContent}, - receipt::ReceiptEvent, - SyncEphemeralRoomEvent, - }, - presence::PresenceState, - serde::Raw, - signatures::CanonicalJsonObject, - RoomId, UInt, UserId, -}; -use std::{ - collections::{HashMap, HashSet}, - mem, - sync::Arc, -}; - -pub struct RoomEdus { - pub(in super::super) readreceiptid_readreceipt: Arc<dyn Tree>, // ReadReceiptId = RoomId + Count + UserId - pub(in super::super) roomuserid_privateread: Arc<dyn Tree>, // RoomUserId = Room + User, PrivateRead = Count - pub(in super::super) roomuserid_lastprivatereadupdate: Arc<dyn Tree>, // LastPrivateReadUpdate = Count - pub(in super::super) typingid_userid: Arc<dyn Tree>, // TypingId = RoomId + TimeoutTime + Count - pub(in super::super) roomid_lasttypingupdate: Arc<dyn Tree>, // LastRoomTypingUpdate = Count - pub(in super::super) presenceid_presence: Arc<dyn Tree>, // PresenceId = RoomId + Count + UserId - pub(in super::super) userid_lastpresenceupdate: Arc<dyn Tree>, // LastPresenceUpdate = Count -} - -impl RoomEdus { - /// Adds an event which will be saved until a new event replaces it (e.g. read receipt). - pub fn readreceipt_update( - &self, - user_id: &UserId, - room_id: &RoomId, - event: ReceiptEvent, - globals: &super::super::globals::Globals, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - // Remove old entry - if let Some((old, _)) = self - .readreceiptid_readreceipt - .iter_from(&last_possible_key, true) - .take_while(|(key, _)| key.starts_with(&prefix)) - .find(|(key, _)| { - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element") - == user_id.as_bytes() - }) - { - // This is the old room_latest - self.readreceiptid_readreceipt.remove(&old)?; - } - - let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&globals.next_count()?.to_be_bytes()); - room_latest_id.push(0xff); - room_latest_id.extend_from_slice(user_id.as_bytes()); - - self.readreceiptid_readreceipt.insert( - &room_latest_id, - &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), - )?; - - Ok(()) - } - - /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. - #[tracing::instrument(skip(self))] - pub fn readreceipts_since<'a>( - &'a self, - room_id: &RoomId, - since: u64, - ) -> impl Iterator< - Item = Result<( - Box<UserId>, - u64, - Raw<ruma::events::AnySyncEphemeralRoomEvent>, - )>, - > + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - let prefix2 = prefix.clone(); - - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since - - self.readreceiptid_readreceipt - .iter_from(&first_possible_edu, false) - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(k, v)| { - let count = - utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()]) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; - let user_id = UserId::parse( - utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..]) - .map_err(|_| { - Error::bad_database("Invalid readreceiptid userid bytes in db.") - })?, - ) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; - - let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| { - Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json.") - })?; - json.remove("room_id"); - - Ok(( - user_id, - count, - Raw::from_json( - serde_json::value::to_raw_value(&json).expect("json is valid raw value"), - ), - )) - }) - } - - /// Sets a private read marker at `count`. - #[tracing::instrument(skip(self, globals))] - pub fn private_read_set( - &self, - room_id: &RoomId, - user_id: &UserId, - count: u64, - globals: &super::super::globals::Globals, - ) -> Result<()> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; - - self.roomuserid_lastprivatereadupdate - .insert(&key, &globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - /// Returns the private read marker. - #[tracing::instrument(skip(self))] - pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { - Error::bad_database("Invalid private read marker bytes") - })?)) - }) - } - - /// Returns the count of the last typing update in this room. - pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - Ok(self - .roomuserid_lastprivatereadupdate - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") - }) - }) - .transpose()? - .unwrap_or(0)) - } - - /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is - /// called. - pub fn typing_add( - &self, - user_id: &UserId, - room_id: &RoomId, - timeout: u64, - globals: &super::super::globals::Globals, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let count = globals.next_count()?.to_be_bytes(); - - let mut room_typing_id = prefix; - room_typing_id.extend_from_slice(&timeout.to_be_bytes()); - room_typing_id.push(0xff); - room_typing_id.extend_from_slice(&count); - - self.typingid_userid - .insert(&room_typing_id, &*user_id.as_bytes())?; - - self.roomid_lasttypingupdate - .insert(room_id.as_bytes(), &count)?; - - Ok(()) - } - - /// Removes a user from typing before the timeout is reached. - pub fn typing_remove( - &self, - user_id: &UserId, - room_id: &RoomId, - globals: &super::super::globals::Globals, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let user_id = user_id.to_string(); - - let mut found_outdated = false; - - // Maybe there are multiple ones from calling roomtyping_add multiple times - for outdated_edu in self - .typingid_userid - .scan_prefix(prefix) - .filter(|(_, v)| &**v == user_id.as_bytes()) - { - self.typingid_userid.remove(&outdated_edu.0)?; - found_outdated = true; - } - - if found_outdated { - self.roomid_lasttypingupdate - .insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; - } - - Ok(()) - } - - /// Makes sure that typing events with old timestamps get removed. - fn typings_maintain( - &self, - room_id: &RoomId, - globals: &super::super::globals::Globals, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let current_timestamp = utils::millis_since_unix_epoch(); - - let mut found_outdated = false; - - // Find all outdated edus before inserting a new one - for outdated_edu in self - .typingid_userid - .scan_prefix(prefix) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes( - &key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| { - Error::bad_database("RoomTyping has invalid timestamp or delimiters.") - })?[0..mem::size_of::<u64>()], - ) - .map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?, - )) - }) - .filter_map(|r| r.ok()) - .take_while(|&(_, timestamp)| timestamp < current_timestamp) - { - // This is an outdated edu (time > timestamp) - self.typingid_userid.remove(&outdated_edu.0)?; - found_outdated = true; - } - - if found_outdated { - self.roomid_lasttypingupdate - .insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; - } - - Ok(()) - } - - /// Returns the count of the last typing update in this room. - #[tracing::instrument(skip(self, globals))] - pub fn last_typing_update( - &self, - room_id: &RoomId, - globals: &super::super::globals::Globals, - ) -> Result<u64> { - self.typings_maintain(room_id, globals)?; - - Ok(self - .roomid_lasttypingupdate - .get(room_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") - }) - }) - .transpose()? - .unwrap_or(0)) - } - - pub fn typings_all( - &self, - room_id: &RoomId, - ) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let mut user_ids = HashSet::new(); - - for (_, user_id) in self.typingid_userid.scan_prefix(prefix) { - let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| { - Error::bad_database("User ID in typingid_userid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?; - - user_ids.insert(user_id); - } - - Ok(SyncEphemeralRoomEvent { - content: ruma::events::typing::TypingEventContent { - user_ids: user_ids.into_iter().collect(), - }, - }) - } - - /// Adds a presence event which will be saved until a new event replaces it. - /// - /// Note: This method takes a RoomId because presence updates are always bound to rooms to - /// make sure users outside these rooms can't see them. - pub fn update_presence( - &self, - user_id: &UserId, - room_id: &RoomId, - presence: PresenceEvent, - globals: &super::super::globals::Globals, - ) -> Result<()> { - // TODO: Remove old entry? Or maybe just wipe completely from time to time? - - let count = globals.next_count()?.to_be_bytes(); - - let mut presence_id = room_id.as_bytes().to_vec(); - presence_id.push(0xff); - presence_id.extend_from_slice(&count); - presence_id.push(0xff); - presence_id.extend_from_slice(presence.sender.as_bytes()); - - self.presenceid_presence.insert( - &presence_id, - &serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"), - )?; - - self.userid_lastpresenceupdate.insert( - user_id.as_bytes(), - &utils::millis_since_unix_epoch().to_be_bytes(), - )?; - - Ok(()) - } - - /// Resets the presence timeout, so the user will stay in their current presence state. - #[tracing::instrument(skip(self))] - pub fn ping_presence(&self, user_id: &UserId) -> Result<()> { - self.userid_lastpresenceupdate.insert( - user_id.as_bytes(), - &utils::millis_since_unix_epoch().to_be_bytes(), - )?; - - Ok(()) - } - - /// Returns the timestamp of the last presence update of this user in millis since the unix epoch. - pub fn last_presence_update(&self, user_id: &UserId) -> Result<Option<u64>> { - self.userid_lastpresenceupdate - .get(user_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.") - }) - }) - .transpose() - } - - pub fn get_last_presence_event( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result<Option<PresenceEvent>> { - let last_update = match self.last_presence_update(user_id)? { - Some(last) => last, - None => return Ok(None), - }; - - let mut presence_id = room_id.as_bytes().to_vec(); - presence_id.push(0xff); - presence_id.extend_from_slice(&last_update.to_be_bytes()); - presence_id.push(0xff); - presence_id.extend_from_slice(user_id.as_bytes()); - - self.presenceid_presence - .get(&presence_id)? - .map(|value| { - let mut presence: PresenceEvent = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Invalid presence event in db."))?; - let current_timestamp: UInt = utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"); - - if presence.content.presence == PresenceState::Online { - // Don't set last_active_ago when the user is online - presence.content.last_active_ago = None; - } else { - // Convert from timestamp to duration - presence.content.last_active_ago = presence - .content - .last_active_ago - .map(|timestamp| current_timestamp - timestamp); - } - - Ok(presence) - }) - .transpose() - } - - /// Sets all users to offline who have been quiet for too long. - fn _presence_maintain( - &self, - rooms: &super::Rooms, - globals: &super::super::globals::Globals, - ) -> Result<()> { - let current_timestamp = utils::millis_since_unix_epoch(); - - for (user_id_bytes, last_timestamp) in self - .userid_lastpresenceupdate - .iter() - .filter_map(|(k, bytes)| { - Some(( - k, - utils::u64_from_bytes(&bytes) - .map_err(|_| { - Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.") - }) - .ok()?, - )) - }) - .take_while(|(_, timestamp)| current_timestamp.saturating_sub(*timestamp) > 5 * 60_000) - // 5 Minutes - { - // Send new presence events to set the user offline - let count = globals.next_count()?.to_be_bytes(); - let user_id: Box<_> = utils::string_from_bytes(&user_id_bytes) - .map_err(|_| { - Error::bad_database("Invalid UserId bytes in userid_lastpresenceupdate.") - })? - .try_into() - .map_err(|_| Error::bad_database("Invalid UserId in userid_lastpresenceupdate."))?; - for room_id in rooms.rooms_joined(&user_id).filter_map(|r| r.ok()) { - let mut presence_id = room_id.as_bytes().to_vec(); - presence_id.push(0xff); - presence_id.extend_from_slice(&count); - presence_id.push(0xff); - presence_id.extend_from_slice(&user_id_bytes); - - self.presenceid_presence.insert( - &presence_id, - &serde_json::to_vec(&PresenceEvent { - content: PresenceEventContent { - avatar_url: None, - currently_active: None, - displayname: None, - last_active_ago: Some( - last_timestamp.try_into().expect("time is valid"), - ), - presence: PresenceState::Offline, - status_msg: None, - }, - sender: user_id.to_owned(), - }) - .expect("PresenceEvent can be serialized"), - )?; - } - - self.userid_lastpresenceupdate.insert( - user_id.as_bytes(), - &utils::millis_since_unix_epoch().to_be_bytes(), - )?; - } - - Ok(()) - } - - /// Returns an iterator over the most recent presence updates that happened after the event with id `since`. - #[tracing::instrument(skip(self, since, _rooms, _globals))] - pub fn presence_since( - &self, - room_id: &RoomId, - since: u64, - _rooms: &super::Rooms, - _globals: &super::super::globals::Globals, - ) -> Result<HashMap<Box<UserId>, PresenceEvent>> { - //self.presence_maintain(rooms, globals)?; - - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since - let mut hashmap = HashMap::new(); - - for (key, value) in self - .presenceid_presence - .iter_from(&*first_possible_edu, false) - .take_while(|(key, _)| key.starts_with(&prefix)) - { - let user_id = UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("Invalid UserId bytes in presenceid_presence."))?, - ) - .map_err(|_| Error::bad_database("Invalid UserId in presenceid_presence."))?; - - let mut presence: PresenceEvent = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Invalid presence event in db."))?; - - let current_timestamp: UInt = utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"); - - if presence.content.presence == PresenceState::Online { - // Don't set last_active_ago when the user is online - presence.content.last_active_ago = None; - } else { - // Convert from timestamp to duration - presence.content.last_active_ago = presence - .content - .last_active_ago - .map(|timestamp| current_timestamp - timestamp); - } - - hashmap.insert(user_id, presence); - } - - Ok(hashmap) - } -} diff --git a/src/database/sending.rs b/src/database/sending.rs deleted file mode 100644 index 4c830d6..0000000 --- a/src/database/sending.rs +++ /dev/null @@ -1,845 +0,0 @@ -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - fmt::Debug, - sync::Arc, - time::{Duration, Instant}, -}; - -use crate::{ - appservice_server, database::pusher, server_server, utils, Database, Error, PduEvent, Result, -}; -use federation::transactions::send_transaction_message; -use futures_util::{stream::FuturesUnordered, StreamExt}; -use ring::digest; -use ruma::{ - api::{ - appservice, - federation::{ - self, - transactions::edu::{ - DeviceListUpdateContent, Edu, ReceiptContent, ReceiptData, ReceiptMap, - }, - }, - OutgoingRequest, - }, - device_id, - events::{push_rules::PushRulesEvent, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, - push, - receipt::ReceiptType, - uint, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, -}; -use tokio::{ - select, - sync::{mpsc, RwLock, Semaphore}, -}; -use tracing::{error, warn}; - -use super::abstraction::Tree; - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum OutgoingKind { - Appservice(String), - Push(Vec<u8>, Vec<u8>), // user and pushkey - Normal(Box<ServerName>), -} - -impl OutgoingKind { - #[tracing::instrument(skip(self))] - pub fn get_prefix(&self) -> Vec<u8> { - let mut prefix = match self { - OutgoingKind::Appservice(server) => { - let mut p = b"+".to_vec(); - p.extend_from_slice(server.as_bytes()); - p - } - OutgoingKind::Push(user, pushkey) => { - let mut p = b"$".to_vec(); - p.extend_from_slice(user); - p.push(0xff); - p.extend_from_slice(pushkey); - p - } - OutgoingKind::Normal(server) => { - let mut p = Vec::new(); - p.extend_from_slice(server.as_bytes()); - p - } - }; - prefix.push(0xff); - - prefix - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum SendingEventType { - Pdu(Vec<u8>), - Edu(Vec<u8>), -} - -pub struct Sending { - /// The state for a given state hash. - pub(super) servername_educount: Arc<dyn Tree>, // EduCount: Count of last EDU sync - pub(super) servernameevent_data: Arc<dyn Tree>, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content - pub(super) servercurrentevent_data: Arc<dyn Tree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content - pub(super) maximum_requests: Arc<Semaphore>, - pub sender: mpsc::UnboundedSender<(Vec<u8>, Vec<u8>)>, -} - -enum TransactionStatus { - Running, - Failed(u32, Instant), // number of times failed, time of last failure - Retrying(u32), // number of times failed -} - -impl Sending { - pub fn start_handler( - &self, - db: Arc<RwLock<Database>>, - mut receiver: mpsc::UnboundedReceiver<(Vec<u8>, Vec<u8>)>, - ) { - tokio::spawn(async move { - let mut futures = FuturesUnordered::new(); - - let mut current_transaction_status = HashMap::<Vec<u8>, TransactionStatus>::new(); - - // Retry requests we could not finish yet - let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new(); - - let guard = db.read().await; - - for (key, outgoing_kind, event) in guard - .sending - .servercurrentevent_data - .iter() - .filter_map(|(key, v)| { - Self::parse_servercurrentevent(&key, v) - .ok() - .map(|(k, e)| (key, k, e)) - }) - { - let entry = initial_transactions - .entry(outgoing_kind.clone()) - .or_insert_with(Vec::new); - - if entry.len() > 30 { - warn!( - "Dropping some current events: {:?} {:?} {:?}", - key, outgoing_kind, event - ); - guard.sending.servercurrentevent_data.remove(&key).unwrap(); - continue; - } - - entry.push(event); - } - - drop(guard); - - for (outgoing_kind, events) in initial_transactions { - current_transaction_status - .insert(outgoing_kind.get_prefix(), TransactionStatus::Running); - futures.push(Self::handle_events( - outgoing_kind.clone(), - events, - Arc::clone(&db), - )); - } - - loop { - select! { - Some(response) = futures.next() => { - match response { - Ok(outgoing_kind) => { - let guard = db.read().await; - - let prefix = outgoing_kind.get_prefix(); - for (key, _) in guard.sending.servercurrentevent_data - .scan_prefix(prefix.clone()) - { - guard.sending.servercurrentevent_data.remove(&key).unwrap(); - } - - // Find events that have been added since starting the last request - let new_events: Vec<_> = guard.sending.servernameevent_data - .scan_prefix(prefix.clone()) - .filter_map(|(k, v)| { - Self::parse_servercurrentevent(&k, v).ok().map(|ev| (ev, k)) - }) - .take(30) - .collect(); - - // TODO: find edus - - if !new_events.is_empty() { - // Insert pdus we found - for (e, key) in &new_events { - let value = if let SendingEventType::Edu(value) = &e.1 { &**value } else { &[] }; - guard.sending.servercurrentevent_data.insert(key, value).unwrap(); - guard.sending.servernameevent_data.remove(key).unwrap(); - } - - drop(guard); - - futures.push( - Self::handle_events( - outgoing_kind.clone(), - new_events.into_iter().map(|(event, _)| event.1).collect(), - Arc::clone(&db), - ) - ); - } else { - current_transaction_status.remove(&prefix); - } - } - Err((outgoing_kind, _)) => { - current_transaction_status.entry(outgoing_kind.get_prefix()).and_modify(|e| *e = match e { - TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), - TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n+1, Instant::now()), - TransactionStatus::Failed(_, _) => { - error!("Request that was not even running failed?!"); - return - }, - }); - } - }; - }, - Some((key, value)) = receiver.recv() => { - if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key, value) { - let guard = db.read().await; - - if let Ok(Some(events)) = Self::select_events( - &outgoing_kind, - vec![(event, key)], - &mut current_transaction_status, - &guard - ) { - futures.push(Self::handle_events(outgoing_kind, events, Arc::clone(&db))); - } - } - } - } - } - }); - } - - #[tracing::instrument(skip(outgoing_kind, new_events, current_transaction_status, db))] - fn select_events( - outgoing_kind: &OutgoingKind, - new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key - current_transaction_status: &mut HashMap<Vec<u8>, TransactionStatus>, - db: &Database, - ) -> Result<Option<Vec<SendingEventType>>> { - let mut retry = false; - let mut allow = true; - - let prefix = outgoing_kind.get_prefix(); - let entry = current_transaction_status.entry(prefix.clone()); - - entry - .and_modify(|e| match e { - TransactionStatus::Running | TransactionStatus::Retrying(_) => { - allow = false; // already running - } - TransactionStatus::Failed(tries, time) => { - // Fail if a request has failed recently (exponential backoff) - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - allow = false; - } else { - retry = true; - *e = TransactionStatus::Retrying(*tries); - } - } - }) - .or_insert(TransactionStatus::Running); - - if !allow { - return Ok(None); - } - - let mut events = Vec::new(); - - if retry { - // We retry the previous transaction - for (key, value) in db.sending.servercurrentevent_data.scan_prefix(prefix) { - if let Ok((_, e)) = Self::parse_servercurrentevent(&key, value) { - events.push(e); - } - } - } else { - for (e, full_key) in new_events { - let value = if let SendingEventType::Edu(value) = &e { - &**value - } else { - &[][..] - }; - db.sending - .servercurrentevent_data - .insert(&full_key, value)?; - - // If it was a PDU we have to unqueue it - // TODO: don't try to unqueue EDUs - db.sending.servernameevent_data.remove(&full_key)?; - - events.push(e); - } - - if let OutgoingKind::Normal(server_name) = outgoing_kind { - if let Ok((select_edus, last_count)) = Self::select_edus(db, server_name) { - events.extend(select_edus.into_iter().map(SendingEventType::Edu)); - - db.sending - .servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes())?; - } - } - } - - Ok(Some(events)) - } - - #[tracing::instrument(skip(db, server))] - pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { - // u64: count of last edu - let since = db - .sending - .servername_educount - .get(server.as_bytes())? - .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - })?; - let mut events = Vec::new(); - let mut max_edu_count = since; - let mut device_list_changes = HashSet::new(); - - 'outer: for room_id in db.rooms.server_rooms(server) { - let room_id = room_id?; - // Look for device list updates in this room - device_list_changes.extend( - db.users - .keys_changed(&room_id.to_string(), since, None) - .filter_map(|r| r.ok()) - .filter(|user_id| user_id.server_name() == db.globals.server_name()), - ); - - // Look for read receipts in this room - for r in db.rooms.edus.readreceipts_since(&room_id, since) { - let (user_id, count, read_receipt) = r?; - - if count > max_edu_count { - max_edu_count = count; - } - - if user_id.server_name() != db.globals.server_name() { - continue; - } - - let event: AnySyncEphemeralRoomEvent = - serde_json::from_str(read_receipt.json().get()) - .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; - let federation_event = match event { - AnySyncEphemeralRoomEvent::Receipt(r) => { - let mut read = BTreeMap::new(); - - let (event_id, mut receipt) = r - .content - .0 - .into_iter() - .next() - .expect("we only use one event per read receipt"); - let receipt = receipt - .remove(&ReceiptType::Read) - .expect("our read receipts always set this") - .remove(&user_id) - .expect("our read receipts always have the user here"); - - read.insert( - user_id, - ReceiptData { - data: receipt.clone(), - event_ids: vec![event_id.clone()], - }, - ); - - let receipt_map = ReceiptMap { read }; - - let mut receipts = BTreeMap::new(); - receipts.insert(room_id.clone(), receipt_map); - - Edu::Receipt(ReceiptContent { receipts }) - } - _ => { - Error::bad_database("Invalid event type in read_receipts"); - continue; - } - }; - - events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); - - if events.len() >= 20 { - break 'outer; - } - } - } - - for user_id in device_list_changes { - // Empty prev id forces synapse to resync: https://github.com/matrix-org/synapse/blob/98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca/synapse/handlers/device.py#L767 - // Because synapse resyncs, we can just insert dummy data - let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { - user_id, - device_id: device_id!("dummy").to_owned(), - device_display_name: Some("Dummy".to_owned()), - stream_id: uint!(1), - prev_id: Vec::new(), - deleted: None, - keys: None, - }); - - events.push(serde_json::to_vec(&edu).expect("json can be serialized")); - } - - Ok((events, max_edu_count)) - } - - #[tracing::instrument(skip(self, pdu_id, senderkey))] - pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Vec<u8>) -> Result<()> { - let mut key = b"$".to_vec(); - key.extend_from_slice(&senderkey); - key.push(0xff); - key.extend_from_slice(pdu_id); - self.servernameevent_data.insert(&key, &[])?; - self.sender.send((key, vec![])).unwrap(); - - Ok(()) - } - - #[tracing::instrument(skip(self, servers, pdu_id))] - pub fn send_pdu<I: Iterator<Item = Box<ServerName>>>( - &self, - servers: I, - pdu_id: &[u8], - ) -> Result<()> { - let mut batch = servers.map(|server| { - let mut key = server.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(pdu_id); - - self.sender.send((key.clone(), vec![])).unwrap(); - - (key, Vec::new()) - }); - - self.servernameevent_data.insert_batch(&mut batch)?; - - Ok(()) - } - - #[tracing::instrument(skip(self, server, serialized))] - pub fn send_reliable_edu( - &self, - server: &ServerName, - serialized: Vec<u8>, - id: u64, - ) -> Result<()> { - let mut key = server.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&id.to_be_bytes()); - self.servernameevent_data.insert(&key, &serialized)?; - self.sender.send((key, serialized)).unwrap(); - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn send_pdu_appservice(&self, appservice_id: &str, pdu_id: &[u8]) -> Result<()> { - let mut key = b"+".to_vec(); - key.extend_from_slice(appservice_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(pdu_id); - self.servernameevent_data.insert(&key, &[])?; - self.sender.send((key, vec![])).unwrap(); - - Ok(()) - } - - #[tracing::instrument(skip(keys))] - fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> { - // We only hash the pdu's event ids, not the whole pdu - let bytes = keys.join(&0xff); - let hash = digest::digest(&digest::SHA256, &bytes); - hash.as_ref().to_owned() - } - - /// Cleanup event data - /// Used for instance after we remove an appservice registration - /// - #[tracing::instrument(skip(self))] - pub fn cleanup_events(&self, key_id: &str) -> Result<()> { - let mut prefix = b"+".to_vec(); - prefix.extend_from_slice(key_id.as_bytes()); - prefix.push(0xff); - - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { - self.servercurrentevent_data.remove(&key).unwrap(); - } - - for (key, _) in self.servernameevent_data.scan_prefix(prefix.clone()) { - self.servernameevent_data.remove(&key).unwrap(); - } - - Ok(()) - } - - #[tracing::instrument(skip(db, events, kind))] - async fn handle_events( - kind: OutgoingKind, - events: Vec<SendingEventType>, - db: Arc<RwLock<Database>>, - ) -> Result<OutgoingKind, (OutgoingKind, Error)> { - let db = db.read().await; - - match &kind { - OutgoingKind::Appservice(id) => { - let mut pdu_jsons = Vec::new(); - - for event in &events { - match event { - SendingEventType::Pdu(pdu_id) => { - pdu_jsons.push(db.rooms - .get_pdu_from_id(pdu_id) - .map_err(|e| (kind.clone(), e))? - .ok_or_else(|| { - ( - kind.clone(), - Error::bad_database( - "[Appservice] Event in servernameevent_data not found in db.", - ), - ) - })? - .to_room_event()) - } - SendingEventType::Edu(_) => { - // Appservices don't need EDUs (?) - } - } - } - - let permit = db.sending.maximum_requests.acquire().await; - - let response = appservice_server::send_request( - &db.globals, - db.appservice - .get_registration(&id) - .map_err(|e| (kind.clone(), e))? - .ok_or_else(|| { - ( - kind.clone(), - Error::bad_database( - "[Appservice] Could not load registration from db.", - ), - ) - })?, - appservice::event::push_events::v1::Request { - events: &pdu_jsons, - txn_id: (&*base64::encode_config( - Self::calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, - }) - .collect::<Vec<_>>(), - ), - base64::URL_SAFE_NO_PAD, - )) - .into(), - }, - ) - .await - .map(|_response| kind.clone()) - .map_err(|e| (kind, e)); - - drop(permit); - - response - } - OutgoingKind::Push(user, pushkey) => { - let mut pdus = Vec::new(); - - for event in &events { - match event { - SendingEventType::Pdu(pdu_id) => { - pdus.push( - db.rooms - .get_pdu_from_id(pdu_id) - .map_err(|e| (kind.clone(), e))? - .ok_or_else(|| { - ( - kind.clone(), - Error::bad_database( - "[Push] Event in servernamevent_datas not found in db.", - ), - ) - })?, - ); - } - SendingEventType::Edu(_) => { - // Push gateways don't need EDUs (?) - } - } - } - - for pdu in pdus { - // Redacted events are not notification targets (we don't send push for them) - if let Some(unsigned) = &pdu.unsigned { - if let Ok(unsigned) = - serde_json::from_str::<serde_json::Value>(unsigned.get()) - { - if unsigned.get("redacted_because").is_some() { - continue; - } - } - } - - let userid = UserId::parse(utils::string_from_bytes(user).map_err(|_| { - ( - kind.clone(), - Error::bad_database("Invalid push user string in db."), - ) - })?) - .map_err(|_| { - ( - kind.clone(), - Error::bad_database("Invalid push user id in db."), - ) - })?; - - let mut senderkey = user.clone(); - senderkey.push(0xff); - senderkey.extend_from_slice(pushkey); - - let pusher = match db - .pusher - .get_pusher(&senderkey) - .map_err(|e| (OutgoingKind::Push(user.clone(), pushkey.clone()), e))? - { - Some(pusher) => pusher, - None => continue, - }; - - let rules_for_user = db - .account_data - .get( - None, - &userid, - GlobalAccountDataEventType::PushRules.to_string().into(), - ) - .unwrap_or_default() - .map(|ev: PushRulesEvent| ev.content.global) - .unwrap_or_else(|| push::Ruleset::server_default(&userid)); - - let unread: UInt = db - .rooms - .notification_count(&userid, &pdu.room_id) - .map_err(|e| (kind.clone(), e))? - .try_into() - .expect("notifiation count can't go that high"); - - let permit = db.sending.maximum_requests.acquire().await; - - let _response = pusher::send_push_notice( - &userid, - unread, - &pusher, - rules_for_user, - &pdu, - &db, - ) - .await - .map(|_response| kind.clone()) - .map_err(|e| (kind.clone(), e)); - - drop(permit); - } - Ok(OutgoingKind::Push(user.clone(), pushkey.clone())) - } - OutgoingKind::Normal(server) => { - let mut edu_jsons = Vec::new(); - let mut pdu_jsons = Vec::new(); - - for event in &events { - match event { - SendingEventType::Pdu(pdu_id) => { - // TODO: check room version and remove event_id if needed - let raw = PduEvent::convert_to_outgoing_federation_event( - db.rooms - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (OutgoingKind::Normal(server.clone()), e))? - .ok_or_else(|| { - ( - OutgoingKind::Normal(server.clone()), - Error::bad_database( - "[Normal] Event in servernamevent_datas not found in db.", - ), - ) - })?, - ); - pdu_jsons.push(raw); - } - SendingEventType::Edu(edu) => { - if let Ok(raw) = serde_json::from_slice(edu) { - edu_jsons.push(raw); - } - } - } - } - - let permit = db.sending.maximum_requests.acquire().await; - - let response = server_server::send_request( - &db.globals, - &*server, - send_transaction_message::v1::Request { - origin: db.globals.server_name(), - pdus: &pdu_jsons, - edus: &edu_jsons, - origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - transaction_id: (&*base64::encode_config( - Self::calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, - }) - .collect::<Vec<_>>(), - ), - base64::URL_SAFE_NO_PAD, - )) - .into(), - }, - ) - .await - .map(|response| { - for pdu in response.pdus { - if pdu.1.is_err() { - warn!("Failed to send to {}: {:?}", server, pdu); - } - } - kind.clone() - }) - .map_err(|e| (kind, e)); - - drop(permit); - - response - } - } - } - - #[tracing::instrument(skip(key))] - fn parse_servercurrentevent( - key: &[u8], - value: Vec<u8>, - ) -> Result<(OutgoingKind, SendingEventType)> { - // Appservices start with a plus - Ok::<_, Error>(if key.starts_with(b"+") { - let mut parts = key[1..].splitn(2, |&b| b == 0xff); - - let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let server = utils::string_from_bytes(server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") - })?; - - ( - OutgoingKind::Appservice(server), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - SendingEventType::Edu(value) - }, - ) - } else if key.starts_with(b"$") { - let mut parts = key[1..].splitn(3, |&b| b == 0xff); - - let user = parts.next().expect("splitn always returns one element"); - let pushkey = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - ( - OutgoingKind::Push(user.to_vec(), pushkey.to_vec()), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - SendingEventType::Edu(value) - }, - ) - } else { - let mut parts = key.splitn(2, |&b| b == 0xff); - - let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let server = utils::string_from_bytes(server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") - })?; - - ( - OutgoingKind::Normal(ServerName::parse(server).map_err(|_| { - Error::bad_database("Invalid server string in server_currenttransaction") - })?), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - SendingEventType::Edu(value) - }, - ) - }) - } - - #[tracing::instrument(skip(self, globals, destination, request))] - pub async fn send_federation_request<T: OutgoingRequest>( - &self, - globals: &crate::database::globals::Globals, - destination: &ServerName, - request: T, - ) -> Result<T::IncomingResponse> - where - T: Debug, - { - let permit = self.maximum_requests.acquire().await; - let response = server_server::send_request(globals, destination, request).await; - drop(permit); - - response - } - - #[tracing::instrument(skip(self, globals, registration, request))] - pub async fn send_appservice_request<T: OutgoingRequest>( - &self, - globals: &crate::database::globals::Globals, - registration: serde_yaml::Value, - request: T, - ) -> Result<T::IncomingResponse> - where - T: Debug, - { - let permit = self.maximum_requests.acquire().await; - let response = appservice_server::send_request(globals, registration, request).await; - drop(permit); - - response - } -} diff --git a/src/database/uiaa.rs b/src/database/uiaa.rs deleted file mode 100644 index 1237313..0000000 --- a/src/database/uiaa.rs +++ /dev/null @@ -1,227 +0,0 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; - -use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; -use ruma::{ - api::client::{ - error::ErrorKind, - uiaa::{ - AuthType, IncomingAuthData, IncomingPassword, - IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, - }, - }, - signatures::CanonicalJsonValue, - DeviceId, UserId, -}; -use tracing::error; - -use super::abstraction::Tree; - -pub struct Uiaa { - pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication - pub(super) userdevicesessionid_uiaarequest: - RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>, -} - -impl Uiaa { - /// Creates a new Uiaa session. Make sure the session token is unique. - pub fn create( - &self, - user_id: &UserId, - device_id: &DeviceId, - uiaainfo: &UiaaInfo, - json_body: &CanonicalJsonValue, - ) -> Result<()> { - self.set_uiaa_request( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?) - json_body, - )?; - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), - Some(uiaainfo), - ) - } - - pub fn try_auth( - &self, - user_id: &UserId, - device_id: &DeviceId, - auth: &IncomingAuthData, - uiaainfo: &UiaaInfo, - users: &super::users::Users, - globals: &super::globals::Globals, - ) -> Result<(bool, UiaaInfo)> { - let mut uiaainfo = auth - .session() - .map(|session| self.get_uiaa_session(user_id, device_id, session)) - .unwrap_or_else(|| Ok(uiaainfo.clone()))?; - - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } - - match auth { - // Find out what the user completed - IncomingAuthData::Password(IncomingPassword { - identifier, - password, - .. - }) => { - let username = match identifier { - UserIdOrLocalpart(username) => username, - _ => { - return Err(Error::BadRequest( - ErrorKind::Unrecognized, - "Identifier type not recognized.", - )) - } - }; - - let user_id = - UserId::parse_with_server_name(username.clone(), globals.server_name()) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") - })?; - - // Check if password is correct - if let Some(hash) = users.password_hash(&user_id)? { - let hash_matches = - argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); - - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } - - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push(AuthType::Password); - } - IncomingAuthData::Dummy(_) => { - uiaainfo.completed.push(AuthType::Dummy); - } - k => error!("type not supported: {:?}", k), - } - - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; - } - } - // We didn't break, so this flow succeeded! - completed = true; - } - - if !completed { - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); - } - - // UIAA was successful! Remove this session and return true - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - None, - )?; - Ok((true, uiaainfo)) - } - - fn set_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); - - Ok(()) - } - - pub fn get_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Option<CanonicalJsonValue> { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(|j| j.to_owned()) - } - - fn update_uiaa_session( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - uiaainfo: Option<&UiaaInfo>, - ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } - - fn get_uiaa_session( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Result<UiaaInfo> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest( - ErrorKind::Forbidden, - "UIAA session does not exist.", - ))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) - } -} @@ -7,19 +7,25 @@ #![allow(clippy::suspicious_else_formatting)] #![deny(clippy::dbg_macro)] +pub mod api; mod config; mod database; -mod error; -mod pdu; -mod ruma_wrapper; +mod service; mod utils; -pub mod appservice_server; -pub mod client_server; -pub mod server_server; +use std::sync::RwLock; +pub use api::ruma_wrapper::{Ruma, RumaResponse}; pub use config::Config; -pub use database::Database; -pub use error::{Error, Result}; -pub use pdu::PduEvent; -pub use ruma_wrapper::{Ruma, RumaResponse}; +pub use database::KeyValueDatabase; +pub use service::{pdu::PduEvent, Services}; +pub use utils::error::{Error, Result}; + +pub static SERVICES: RwLock<Option<&'static Services>> = RwLock::new(None); + +pub fn services() -> &'static Services { + SERVICES + .read() + .unwrap() + .expect("SERVICES should be initialized when this is called") +} diff --git a/src/main.rs b/src/main.rs index 9a0928a..da80507 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,34 +7,40 @@ #![allow(clippy::suspicious_else_formatting)] #![deny(clippy::dbg_macro)] -use std::{future::Future, io, net::SocketAddr, sync::Arc, time::Duration}; +use std::{future::Future, io, net::SocketAddr, time::Duration}; use axum::{ - extract::{FromRequest, MatchedPath}, + extract::{DefaultBodyLimit, FromRequest, MatchedPath}, handler::Handler, response::IntoResponse, routing::{get, on, MethodFilter}, Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; +use conduit::api::{client_server, server_server}; use figment::{ providers::{Env, Format, Toml}, Figment, }; use http::{ header::{self, HeaderName}, - Method, Uri, + Method, StatusCode, Uri, }; -use opentelemetry::trace::{FutureExt, Tracer}; -use ruma::api::{client::error::ErrorKind, IncomingRequest}; -use tokio::{signal, sync::RwLock}; +use ruma::api::{ + client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::UiaaResponse, + }, + IncomingRequest, +}; +use tokio::signal; use tower::ServiceBuilder; use tower_http::{ cors::{self, CorsLayer}, trace::TraceLayer, ServiceBuilderExt as _, }; -use tracing::warn; +use tracing::{error, info, warn}; use tracing_subscriber::{prelude::*, EnvFilter}; pub use conduit::*; // Re-export everything from the library crate @@ -48,6 +54,7 @@ static GLOBAL: Jemalloc = Jemalloc; #[tokio::main] async fn main() { + // Initialize DB let raw_config = Figment::new() .merge( @@ -61,66 +68,79 @@ async fn main() { let config = match raw_config.extract::<Config>() { Ok(s) => s, Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e); + eprintln!("It looks like your config is invalid. The following error occurred: {e}"); std::process::exit(1); } }; - let start = async { - config.warn_deprecated(); + config.warn_deprecated(); + + if config.allow_jaeger { + opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); + let tracer = opentelemetry_jaeger::new_agent_pipeline() + .with_auto_split_batch(true) + .with_service_name("conduit") + .install_batch(opentelemetry::runtime::Tokio) + .unwrap(); + let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); - let db = match Database::load_or_create(&config).await { - Ok(db) => db, + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, Err(e) => { eprintln!( - "The database couldn't be loaded or created. The following error occured: {}", - e + "It looks like your log config is invalid. The following error occurred: {e}" ); - std::process::exit(1); + EnvFilter::try_new("warn").unwrap() } }; - run_server(&config, db).await.unwrap(); - }; - - if config.allow_jaeger { - opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); - let tracer = opentelemetry_jaeger::new_pipeline() - .install_batch(opentelemetry::runtime::Tokio) - .unwrap(); + let subscriber = tracing_subscriber::Registry::default() + .with(filter_layer) + .with(telemetry); + tracing::subscriber::set_global_default(subscriber).unwrap(); + } else if config.tracing_flame { + let registry = tracing_subscriber::Registry::default(); + let (flame_layer, _guard) = + tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); + let flame_layer = flame_layer.with_empty_samples(false); - let span = tracer.start("conduit"); - start.with_current_context().await; - drop(span); + let filter_layer = EnvFilter::new("trace,h2=off"); - println!("exporting"); - opentelemetry::global::shutdown_tracer_provider(); + let subscriber = registry.with(filter_layer).with(flame_layer); + tracing::subscriber::set_global_default(subscriber).unwrap(); } else { let registry = tracing_subscriber::Registry::default(); - if config.tracing_flame { - let (flame_layer, _guard) = - tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); - let flame_layer = flame_layer.with_empty_samples(false); - - let filter_layer = EnvFilter::new("trace,h2=off"); - - let subscriber = registry.with(filter_layer).with(flame_layer); - tracing::subscriber::set_global_default(subscriber).unwrap(); - start.await; - } else { - let fmt_layer = tracing_subscriber::fmt::Layer::new(); - let filter_layer = EnvFilter::try_new(&config.log) - .or_else(|_| EnvFilter::try_new("info")) - .unwrap(); - - let subscriber = registry.with(filter_layer).with(fmt_layer); - tracing::subscriber::set_global_default(subscriber).unwrap(); - start.await; - } + let fmt_layer = tracing_subscriber::fmt::Layer::new(); + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); + EnvFilter::try_new("warn").unwrap() + } + }; + + let subscriber = registry.with(filter_layer).with(fmt_layer); + tracing::subscriber::set_global_default(subscriber).unwrap(); + } + + info!("Loading database"); + if let Err(error) = KeyValueDatabase::load_or_create(config).await { + error!(?error, "The database couldn't be loaded or created"); + + std::process::exit(1); + }; + let config = &services().globals.config; + + info!("Starting server"); + run_server().await.unwrap(); + + if config.allow_jaeger { + opentelemetry::global::shutdown_tracer_provider(); } } -async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<()> { +async fn run_server() -> io::Result<()> { + let config = &services().globals.config; let addr = SocketAddr::from((config.address, config.port)); let x_requested_with = HeaderName::from_static("x-requested-with"); @@ -139,6 +159,7 @@ async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<() }), ) .compression() + .layer(axum::middleware::from_fn(unrecognized_method)) .layer( CorsLayer::new() .allow_origin(cors::Any) @@ -158,7 +179,12 @@ async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<() ]) .max_age(Duration::from_secs(86400)), ) - .add_extension(db.clone()); + .layer(DefaultBodyLimit::max( + config + .max_request_size + .try_into() + .expect("failed to convert max request size"), + )); let app = routes().layer(middlewares).into_make_service(); let handle = ServerHandle::new(); @@ -168,19 +194,54 @@ async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<() match &config.tls { Some(tls) => { let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; - bind_rustls(addr, conf).handle(handle).serve(app).await?; + let server = bind_rustls(addr, conf).handle(handle).serve(app); + + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + + server.await? } None => { - bind(addr).handle(handle).serve(app).await?; + let server = bind(addr).handle(handle).serve(app); + + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + + server.await? } } - // After serve exits and before exiting, shutdown the DB - Database::on_shutdown(db).await; + // On shutdown + info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); + services().globals.rotate.fire(); + + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); Ok(()) } +async fn unrecognized_method<B>( + req: axum::http::Request<B>, + next: axum::middleware::Next<B>, +) -> std::result::Result<axum::response::Response, StatusCode> { + let method = req.method().clone(); + let uri = req.uri().clone(); + let inner = next.run(req).await; + if inner.status() == axum::http::StatusCode::METHOD_NOT_ALLOWED { + warn!("Method not allowed: {method} {uri}"); + return Ok(RumaResponse(UiaaResponse::MatrixError(RumaError { + body: ErrorBody::Standard { + kind: ErrorKind::Unrecognized, + message: "M_UNRECOGNIZED: Unrecognized request".to_owned(), + }, + status_code: StatusCode::METHOD_NOT_ALLOWED, + })) + .into_response()); + } + Ok(inner) +} + fn routes() -> Router { Router::new() .ruma_route(client_server::get_supported_versions_route) @@ -194,6 +255,8 @@ fn routes() -> Router { .ruma_route(client_server::change_password_route) .ruma_route(client_server::deactivate_route) .ruma_route(client_server::third_party_route) + .ruma_route(client_server::request_3pid_management_token_via_email_route) + .ruma_route(client_server::request_3pid_management_token_via_msisdn_route) .ruma_route(client_server::get_capabilities_route) .ruma_route(client_server::get_pushrules_all_route) .ruma_route(client_server::set_pushrule_route) @@ -340,6 +403,14 @@ fn routes() -> Router { .ruma_route(server_server::get_profile_information_route) .ruma_route(server_server::get_keys_route) .ruma_route(server_server::claim_keys_route) + .route( + "/_matrix/client/r0/rooms/:room_id/initialSync", + get(initial_sync), + ) + .route( + "/_matrix/client/v3/rooms/:room_id/initialSync", + get(initial_sync), + ) .fallback(not_found.into_service()) } @@ -372,8 +443,16 @@ async fn shutdown_signal(handle: ServerHandle) { handle.graceful_shutdown(Some(Duration::from_secs(30))); } -async fn not_found(_uri: Uri) -> impl IntoResponse { - Error::BadRequest(ErrorKind::NotFound, "Unknown or unimplemented route") +async fn not_found(uri: Uri) -> impl IntoResponse { + warn!("Not found: {uri}"); + Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") +} + +async fn initial_sync(_uri: Uri) -> impl IntoResponse { + Error::BadRequest( + ErrorKind::GuestAccessForbidden, + "Guest access not implemented", + ) } trait RouterExt { @@ -417,7 +496,7 @@ macro_rules! impl_ruma_handler { let meta = Req::METADATA; let method_filter = method_to_filter(meta.method); - for path in IntoIterator::into_iter([meta.unstable_path, meta.r0_path, meta.stable_path]).flatten() { + for path in meta.history.all_paths() { let handler = self.clone(); router = router.route(path, on(method_filter, |$( $ty: $ty, )* req| async move { @@ -442,7 +521,7 @@ impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7); impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8); fn method_to_filter(method: Method) -> MethodFilter { - let method_filter = match method { + match method { Method::DELETE => MethodFilter::DELETE, Method::GET => MethodFilter::GET, Method::HEAD => MethodFilter::HEAD, @@ -451,7 +530,6 @@ fn method_to_filter(method: Method) -> MethodFilter { Method::POST => MethodFilter::POST, Method::PUT => MethodFilter::PUT, Method::TRACE => MethodFilter::TRACE, - m => panic!("Unsupported HTTP method: {:?}", m), - }; - method_filter + m => panic!("Unsupported HTTP method: {m:?}"), + } } diff --git a/src/server_server.rs b/src/server_server.rs deleted file mode 100644 index 6fa83e4..0000000 --- a/src/server_server.rs +++ /dev/null @@ -1,3644 +0,0 @@ -use crate::{ - client_server::{self, claim_keys_helper, get_keys_helper}, - database::{rooms::CompressedStateEvent, DatabaseGuard}, - pdu::EventHash, - utils, Database, Error, PduEvent, Result, Ruma, -}; -use axum::{response::IntoResponse, Json}; -use futures_util::{stream::FuturesUnordered, StreamExt}; -use get_profile_information::v1::ProfileField; -use http::header::{HeaderValue, AUTHORIZATION}; -use regex::Regex; -use ruma::{ - api::{ - client::error::{Error as RumaError, ErrorKind}, - federation::{ - authorization::get_event_authorization, - device::get_devices::{self, v1::UserDevice}, - directory::{get_public_rooms, get_public_rooms_filtered}, - discovery::{ - get_remote_server_keys, get_remote_server_keys_batch, - get_remote_server_keys_batch::v2::QueryCriteria, get_server_keys, - get_server_version, ServerSigningKeys, VerifyKey, - }, - event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, - keys::{claim_keys, get_keys}, - membership::{ - create_invite, - create_join_event::{self, RoomState}, - prepare_join_event, - }, - query::{get_profile_information, get_room_information}, - transactions::{ - edu::{DeviceListUpdateContent, DirectDeviceContent, Edu, SigningKeyUpdateContent}, - send_transaction_message, - }, - }, - EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, OutgoingResponse, - SendAccessToken, - }, - directory::{IncomingFilter, IncomingRoomNetwork}, - events::{ - receipt::{ReceiptEvent, ReceiptEventContent}, - room::{ - create::RoomCreateEventContent, - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - server_acl::RoomServerAclEventContent, - }, - RoomEventType, StateEventType, - }, - int, - receipt::ReceiptType, - serde::{Base64, JsonObject, Raw}, - signatures::{CanonicalJsonObject, CanonicalJsonValue}, - state_res::{self, RoomVersion, StateMap}, - to_device::DeviceIdOrAllDevices, - uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName, - ServerSigningKeyId, -}; -use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use std::{ - collections::{btree_map, hash_map, BTreeMap, BTreeSet, HashMap, HashSet}, - fmt::Debug, - future::Future, - mem, - net::{IpAddr, SocketAddr}, - ops::Deref, - pin::Pin, - sync::{Arc, RwLock, RwLockWriteGuard}, - time::{Duration, Instant, SystemTime}, -}; -use tokio::sync::{MutexGuard, Semaphore}; -use tracing::{debug, error, info, trace, warn}; - -/// Wraps either an literal IP address plus port, or a hostname plus complement -/// (colon-plus-port if it was specified). -/// -/// Note: A `FedDest::Named` might contain an IP address in string form if there -/// was no port specified to construct a SocketAddr with. -/// -/// # Examples: -/// ```rust -/// # use conduit::server_server::FedDest; -/// # fn main() -> Result<(), std::net::AddrParseError> { -/// FedDest::Literal("198.51.100.3:8448".parse()?); -/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); -/// FedDest::Named("matrix.example.org".to_owned(), "".to_owned()); -/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); -/// FedDest::Named("198.51.100.5".to_owned(), "".to_owned()); -/// # Ok(()) -/// # } -/// ``` -#[derive(Clone, Debug, PartialEq)] -pub enum FedDest { - Literal(SocketAddr), - Named(String, String), -} - -impl FedDest { - fn into_https_string(self) -> String { - match self { - Self::Literal(addr) => format!("https://{}", addr), - Self::Named(host, port) => format!("https://{}{}", host, port), - } - } - - fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, ref port) => host + port, - } - } - - fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } - - fn port(&self) -> Option<u16> { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port[1..].parse().ok(), - } - } -} - -#[tracing::instrument(skip(globals, request))] -pub(crate) async fn send_request<T: OutgoingRequest>( - globals: &crate::database::globals::Globals, - destination: &ServerName, - request: T, -) -> Result<T::IncomingResponse> -where - T: Debug, -{ - if !globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let mut write_destination_to_cache = false; - - let cached_result = globals - .actual_destination_cache - .read() - .unwrap() - .get(destination) - .cloned(); - - let (actual_destination, host) = if let Some(result) = cached_result { - result - } else { - write_destination_to_cache = true; - - let result = find_actual_destination(globals, destination).await; - - (result.0, result.1.into_uri_string()) - }; - - let actual_destination_str = actual_destination.clone().into_https_string(); - - let mut http_request = request - .try_into_http_request::<Vec<u8>>( - &actual_destination_str, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!( - "Failed to find destination {}: {}", - actual_destination_str, e - ); - Error::BadServerResponse("Invalid destination") - })?; - - let mut request_map = serde_json::Map::new(); - - if !http_request.body().is_empty() { - request_map.insert( - "content".to_owned(), - serde_json::from_slice(http_request.body()) - .expect("body is valid json, we just created it"), - ); - }; - - request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); - request_map.insert( - "uri".to_owned(), - http_request - .uri() - .path_and_query() - .expect("all requests have a path") - .to_string() - .into(), - ); - request_map.insert("origin".to_owned(), globals.server_name().as_str().into()); - request_map.insert("destination".to_owned(), destination.as_str().into()); - - let mut request_json = - serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); - - ruma::signatures::sign_json( - globals.server_name().as_str(), - globals.keypair(), - &mut request_json, - ) - .expect("our request json is what ruma expects"); - - let request_json: serde_json::Map<String, serde_json::Value> = - serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap(); - - let signatures = request_json["signatures"] - .as_object() - .unwrap() - .values() - .map(|v| { - v.as_object() - .unwrap() - .iter() - .map(|(k, v)| (k, v.as_str().unwrap())) - }); - - for signature_server in signatures { - for s in signature_server { - http_request.headers_mut().insert( - AUTHORIZATION, - HeaderValue::from_str(&format!( - "X-Matrix origin={},key=\"{}\",sig=\"{}\"", - globals.server_name(), - s.0, - s.1 - )) - .unwrap(), - ); - } - } - - let reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); - - let url = reqwest_request.url().clone(); - - let response = globals.federation_client().execute(reqwest_request).await; - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { - warn!( - "{} {}: {}", - url, - status, - String::from_utf8_lossy(&body) - .lines() - .collect::<Vec<_>>() - .join(" ") - ); - } - - let http_response = http_response_builder - .body(body) - .expect("reqwest body is valid http body"); - - if status == 200 { - let response = T::IncomingResponse::try_from_http_response(http_response); - if response.is_ok() && write_destination_to_cache { - globals.actual_destination_cache.write().unwrap().insert( - Box::<ServerName>::from(destination), - (actual_destination, host), - ); - } - - response.map_err(|e| { - warn!( - "Invalid 200 response from {} on: {} {}", - &destination, url, e - ); - Error::BadServerResponse("Server returned bad 200 response.") - }) - } else { - Err(Error::FederationError( - destination.to_owned(), - RumaError::try_from_http_response(http_response).map_err(|e| { - warn!( - "Invalid {} response from {} on: {} {}", - status, &destination, url, e - ); - Error::BadServerResponse("Server returned bad error response.") - })?, - )) - } - } - Err(e) => Err(e.into()), - } -} - -fn get_ip_with_port(destination_str: &str) -> Option<FedDest> { - if let Ok(destination) = destination_str.parse::<SocketAddr>() { - Some(FedDest::Literal(destination)) - } else if let Ok(ip_addr) = destination_str.parse::<IpAddr>() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -fn add_port_to_hostname(destination_str: &str) -> FedDest { - let (host, port) = match destination_str.find(':') { - None => (destination_str, ":8448"), - Some(pos) => destination_str.split_at(pos), - }; - FedDest::Named(host.to_owned(), port.to_owned()) -} - -/// Returns: actual_destination, host header -/// Implemented according to the specification at https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names -/// Numbers in comments below refer to bullet points in linked section of specification -#[tracing::instrument(skip(globals))] -async fn find_actual_destination( - globals: &crate::database::globals::Globals, - destination: &'_ ServerName, -) -> (FedDest, FedDest) { - let destination_str = destination.as_str().to_owned(); - let mut hostname = destination_str.clone(); - let actual_destination = match get_ip_with_port(&destination_str) { - Some(host_port) => { - // 1: IP literal with provided or default port - host_port - } - None => { - if let Some(pos) = destination_str.find(':') { - // 2: Hostname with included port - let (host, port) = destination_str.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - match request_well_known(globals, destination.as_str()).await { - // 3: A .well-known file is available - Some(delegated_hostname) => { - hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); - match get_ip_with_port(&delegated_hostname) { - Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file - None => { - if let Some(pos) = delegated_hostname.find(':') { - // 3.2: Hostname with port in .well-known file - let (host, port) = delegated_hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - // Delegated hostname has no port in this branch - if let Some(hostname_override) = - query_srv_record(globals, &delegated_hostname).await - { - // 3.3: SRV lookup successful - let force_port = hostname_override.port(); - - if let Ok(override_ip) = globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - globals.tls_name_override.write().unwrap().insert( - delegated_hostname.clone(), - ( - override_ip.iter().collect(), - force_port.unwrap_or(8448), - ), - ); - } else { - warn!("Using SRV record, but could not resolve to IP"); - } - - if let Some(port) = force_port { - FedDest::Named(delegated_hostname, format!(":{}", port)) - } else { - add_port_to_hostname(&delegated_hostname) - } - } else { - // 3.4: No SRV records, just use the hostname from .well-known - add_port_to_hostname(&delegated_hostname) - } - } - } - } - } - // 4: No .well-known or an error occured - None => { - match query_srv_record(globals, &destination_str).await { - // 4: SRV record found - Some(hostname_override) => { - let force_port = hostname_override.port(); - - if let Ok(override_ip) = globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - globals.tls_name_override.write().unwrap().insert( - hostname.clone(), - (override_ip.iter().collect(), force_port.unwrap_or(8448)), - ); - } else { - warn!("Using SRV record, but could not resolve to IP"); - } - - if let Some(port) = force_port { - FedDest::Named(hostname.clone(), format!(":{}", port)) - } else { - add_port_to_hostname(&hostname) - } - } - // 5: No SRV record found - None => add_port_to_hostname(&destination_str), - } - } - } - } - } - }; - - // Can't use get_ip_with_port here because we don't want to add a port - // to an IP address if it wasn't specified - let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() { - FedDest::Literal(addr) - } else if let Ok(addr) = hostname.parse::<IpAddr>() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) - } else if let Some(pos) = hostname.find(':') { - let (host, port) = hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - FedDest::Named(hostname, ":8448".to_owned()) - }; - (actual_destination, hostname) -} - -#[tracing::instrument(skip(globals))] -async fn query_srv_record( - globals: &crate::database::globals::Globals, - hostname: &'_ str, -) -> Option<FedDest> { - if let Ok(Some(host_port)) = globals - .dns_resolver() - .srv_lookup(format!("_matrix._tcp.{}", hostname)) - .await - .map(|srv| { - srv.iter().next().map(|result| { - FedDest::Named( - result.target().to_string().trim_end_matches('.').to_owned(), - format!(":{}", result.port()), - ) - }) - }) - { - Some(host_port) - } else { - None - } -} - -#[tracing::instrument(skip(globals))] -async fn request_well_known( - globals: &crate::database::globals::Globals, - destination: &str, -) -> Option<String> { - let body: serde_json::Value = serde_json::from_str( - &globals - .default_client() - .get(&format!( - "https://{}/.well-known/matrix/server", - destination - )) - .send() - .await - .ok()? - .text() - .await - .ok()?, - ) - .ok()?; - Some(body.get("m.server")?.as_str()?.to_owned()) -} - -/// # `GET /_matrix/federation/v1/version` -/// -/// Get version information on this server. -pub async fn get_server_version_route( - db: DatabaseGuard, - _body: Ruma<get_server_version::v1::Request>, -) -> Result<get_server_version::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - Ok(get_server_version::v1::Response { - server: Some(get_server_version::v1::Server { - name: Some("Conduit".to_owned()), - version: Some(env!("CARGO_PKG_VERSION").to_owned()), - }), - }) -} - -/// # `GET /_matrix/key/v2/server` -/// -/// Gets the public signing keys of this server. -/// -/// - Matrix does not support invalidating public keys, so the key returned by this will be valid -/// forever. -// Response type for this endpoint is Json because we need to calculate a signature for the response -pub async fn get_server_keys_route(db: DatabaseGuard) -> Result<impl IntoResponse> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let mut verify_keys: BTreeMap<Box<ServerSigningKeyId>, VerifyKey> = BTreeMap::new(); - verify_keys.insert( - format!("ed25519:{}", db.globals.keypair().version()) - .try_into() - .expect("found invalid server signing keys in DB"), - VerifyKey { - key: Base64::new(db.globals.keypair().public_key().to_vec()), - }, - ); - let mut response = serde_json::from_slice( - get_server_keys::v2::Response { - server_key: Raw::new(&ServerSigningKeys { - server_name: db.globals.server_name().to_owned(), - verify_keys, - old_verify_keys: BTreeMap::new(), - signatures: BTreeMap::new(), - valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() + Duration::from_secs(86400 * 7), - ) - .expect("time is valid"), - }) - .expect("static conversion, no errors"), - } - .try_into_http_response::<Vec<u8>>() - .unwrap() - .body(), - ) - .unwrap(); - - ruma::signatures::sign_json( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut response, - ) - .unwrap(); - - Ok(Json(response)) -} - -/// # `GET /_matrix/key/v2/server/{keyId}` -/// -/// Gets the public signing keys of this server. -/// -/// - Matrix does not support invalidating public keys, so the key returned by this will be valid -/// forever. -pub async fn get_server_keys_deprecated_route(db: DatabaseGuard) -> impl IntoResponse { - get_server_keys_route(db).await -} - -/// # `POST /_matrix/federation/v1/publicRooms` -/// -/// Lists the public rooms on this server. -pub async fn get_public_rooms_filtered_route( - db: DatabaseGuard, - body: Ruma<get_public_rooms_filtered::v1::IncomingRequest>, -) -> Result<get_public_rooms_filtered::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let response = client_server::get_public_rooms_filtered_helper( - &db, - None, - body.limit, - body.since.as_deref(), - &body.filter, - &body.room_network, - ) - .await?; - - Ok(get_public_rooms_filtered::v1::Response { - chunk: response.chunk, - prev_batch: response.prev_batch, - next_batch: response.next_batch, - total_room_count_estimate: response.total_room_count_estimate, - }) -} - -/// # `GET /_matrix/federation/v1/publicRooms` -/// -/// Lists the public rooms on this server. -pub async fn get_public_rooms_route( - db: DatabaseGuard, - body: Ruma<get_public_rooms::v1::IncomingRequest>, -) -> Result<get_public_rooms::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let response = client_server::get_public_rooms_filtered_helper( - &db, - None, - body.limit, - body.since.as_deref(), - &IncomingFilter::default(), - &IncomingRoomNetwork::Matrix, - ) - .await?; - - Ok(get_public_rooms::v1::Response { - chunk: response.chunk, - prev_batch: response.prev_batch, - next_batch: response.next_batch, - total_room_count_estimate: response.total_room_count_estimate, - }) -} - -/// # `PUT /_matrix/federation/v1/send/{txnId}` -/// -/// Push EDUs and PDUs to this server. -pub async fn send_transaction_message_route( - db: DatabaseGuard, - body: Ruma<send_transaction_message::v1::IncomingRequest>, -) -> Result<send_transaction_message::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - let mut resolved_map = BTreeMap::new(); - - let pub_key_map = RwLock::new(BTreeMap::new()); - - // This is all the auth_events that have been recursively fetched so they don't have to be - // deserialized over and over again. - // TODO: make this persist across requests but not in a DB Tree (in globals?) - // TODO: This could potentially also be some sort of trie (suffix tree) like structure so - // that once an auth event is known it would know (using indexes maybe) all of the auth - // events that it references. - // let mut auth_cache = EventMap::new(); - - for pdu in &body.pdus { - // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(pdu, &db) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - continue; - } - }; - - // 0. Check the server is in the room - let room_id = match value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - { - Some(id) => id, - None => { - // Event is invalid - resolved_map.insert(event_id, Err("Event needs a valid RoomId.".to_owned())); - continue; - } - }; - - acl_check(&sender_servername, &room_id, &db)?; - - let mutex = Arc::clone( - db.globals - .roomid_mutex_federation - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let mutex_lock = mutex.lock().await; - let start_time = Instant::now(); - resolved_map.insert( - event_id.clone(), - handle_incoming_pdu( - &sender_servername, - &event_id, - &room_id, - value, - true, - &db, - &pub_key_map, - ) - .await - .map(|_| ()), - ); - drop(mutex_lock); - - let elapsed = start_time.elapsed(); - warn!( - "Handling transaction of event {} took {}m{}s", - event_id, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - } - - for pdu in &resolved_map { - if let Err(e) = pdu.1 { - if e != "Room is unknown to this server." { - warn!("Incoming PDU failed {:?}", pdu); - } - } - } - - for edu in body - .edus - .iter() - .filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok()) - { - match edu { - Edu::Presence(_) => {} - Edu::Receipt(receipt) => { - for (room_id, room_updates) in receipt.receipts { - for (user_id, user_updates) in room_updates.read { - if let Some((event_id, _)) = user_updates - .event_ids - .iter() - .filter_map(|id| { - db.rooms.get_pdu_count(id).ok().flatten().map(|r| (id, r)) - }) - .max_by_key(|(_, count)| *count) - { - let mut user_receipts = BTreeMap::new(); - user_receipts.insert(user_id.clone(), user_updates.data); - - let mut receipts = BTreeMap::new(); - receipts.insert(ReceiptType::Read, user_receipts); - - let mut receipt_content = BTreeMap::new(); - receipt_content.insert(event_id.to_owned(), receipts); - - let event = ReceiptEvent { - content: ReceiptEventContent(receipt_content), - room_id: room_id.clone(), - }; - db.rooms.edus.readreceipt_update( - &user_id, - &room_id, - event, - &db.globals, - )?; - } else { - // TODO fetch missing events - info!("No known event ids in read receipt: {:?}", user_updates); - } - } - } - } - Edu::Typing(typing) => { - if db.rooms.is_joined(&typing.user_id, &typing.room_id)? { - if typing.typing { - db.rooms.edus.typing_add( - &typing.user_id, - &typing.room_id, - 3000 + utils::millis_since_unix_epoch(), - &db.globals, - )?; - } else { - db.rooms.edus.typing_remove( - &typing.user_id, - &typing.room_id, - &db.globals, - )?; - } - } - } - Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { - db.users - .mark_device_key_update(&user_id, &db.rooms, &db.globals)?; - } - Edu::DirectToDevice(DirectDeviceContent { - sender, - ev_type, - message_id, - messages, - }) => { - // Check if this is a new transaction id - if db - .transaction_ids - .existing_txnid(&sender, None, &message_id)? - .is_some() - { - continue; - } - - for (target_user_id, map) in &messages { - for (target_device_id_maybe, event) in map { - match target_device_id_maybe { - DeviceIdOrAllDevices::DeviceId(target_device_id) => { - db.users.add_to_device_event( - &sender, - target_user_id, - target_device_id, - &ev_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, - &db.globals, - )? - } - - DeviceIdOrAllDevices::AllDevices => { - for target_device_id in db.users.all_device_ids(target_user_id) { - db.users.add_to_device_event( - &sender, - target_user_id, - &target_device_id?, - &ev_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, - &db.globals, - )?; - } - } - } - } - } - - // Save transaction id with empty data - db.transaction_ids - .add_txnid(&sender, None, &message_id, &[])?; - } - Edu::SigningKeyUpdate(SigningKeyUpdateContent { - user_id, - master_key, - self_signing_key, - }) => { - if user_id.server_name() != sender_servername { - continue; - } - if let Some(master_key) = master_key { - db.users.add_cross_signing_keys( - &user_id, - &master_key, - &self_signing_key, - &None, - &db.rooms, - &db.globals, - )?; - } - } - Edu::_Custom(_) => {} - } - } - - db.flush()?; - - Ok(send_transaction_message::v1::Response { pdus: resolved_map }) -} - -/// An async function that can recursively call itself. -type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>; - -/// When receiving an event one needs to: -/// 0. Check the server is in the room -/// 1. Skip the PDU if we already know about it -/// 2. Check signatures, otherwise drop -/// 3. Check content hash, redact if doesn't match -/// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not -/// timeline events -/// 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are -/// also rejected "due to auth events" -/// 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events -/// 7. Persist this event as an outlier -/// 8. If not timeline event: stop -/// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline -/// events -/// 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities -/// doing all the checks in this list starting at 1. These are not timeline events -/// 11. Check the auth of the event passes based on the state of the event -/// 12. Ensure that the state is derived from the previous current state (i.e. we calculated by -/// doing state res where one of the inputs was a previously trusted set of state, don't just -/// trust a set of state we got from a remote) -/// 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" -/// it -/// 14. Use state resolution to find new room state -// We use some AsyncRecursiveType hacks here so we can call this async funtion recursively -#[tracing::instrument(skip(value, is_timeline_event, db, pub_key_map))] -pub(crate) async fn handle_incoming_pdu<'a>( - origin: &'a ServerName, - event_id: &'a EventId, - room_id: &'a RoomId, - value: BTreeMap<String, CanonicalJsonValue>, - is_timeline_event: bool, - db: &'a Database, - pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, -) -> Result<Option<Vec<u8>>, String> { - match db.rooms.exists(room_id) { - Ok(true) => {} - _ => { - return Err("Room is unknown to this server.".to_owned()); - } - } - - match db.rooms.is_disabled(room_id) { - Ok(false) => {} - _ => { - return Err("Federation of this room is currently disabled on this server.".to_owned()); - } - } - - // 1. Skip the PDU if we already have it as a timeline event - if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(event_id) { - return Ok(Some(pdu_id.to_vec())); - } - - let create_event = db - .rooms - .room_state_get(room_id, &StateEventType::RoomCreate, "") - .map_err(|_| "Failed to ask database for event.".to_owned())? - .ok_or_else(|| "Failed to find create event in db.".to_owned())?; - - let first_pdu_in_room = db - .rooms - .first_pdu_in_room(room_id) - .map_err(|_| "Error loading first room event.".to_owned())? - .expect("Room exists"); - - let (incoming_pdu, val) = handle_outlier_pdu( - origin, - &create_event, - event_id, - room_id, - value, - db, - pub_key_map, - ) - .await?; - - // 8. if not timeline event: stop - if !is_timeline_event { - return Ok(None); - } - - if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - return Ok(None); - } - - // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events - let mut graph: HashMap<Arc<EventId>, _> = HashMap::new(); - let mut eventid_info = HashMap::new(); - let mut todo_outlier_stack: Vec<Arc<EventId>> = incoming_pdu.prev_events.clone(); - - let mut amount = 0; - - while let Some(prev_event_id) = todo_outlier_stack.pop() { - if let Some((pdu, json_opt)) = fetch_and_handle_outliers( - db, - origin, - &[prev_event_id.clone()], - &create_event, - room_id, - pub_key_map, - ) - .await - .pop() - { - if amount > 100 { - // Max limit reached - warn!("Max prev event limit reached!"); - graph.insert(prev_event_id.clone(), HashSet::new()); - continue; - } - - if let Some(json) = - json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) - { - if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { - amount += 1; - for prev_prev in &pdu.prev_events { - if !graph.contains_key(prev_prev) { - todo_outlier_stack.push(dbg!(prev_prev.clone())); - } - } - - graph.insert( - prev_event_id.clone(), - pdu.prev_events.iter().cloned().collect(), - ); - } else { - // Time based check failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - - eventid_info.insert(prev_event_id.clone(), (pdu, json)); - } else { - // Get json failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - } else { - // Fetch and handle failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - } - - let sorted = state_res::lexicographical_topological_sort(dbg!(&graph), |event_id| { - // This return value is the key used for sorting events, - // events are then sorted by power level, time, - // and lexically by event_id. - println!("{}", event_id); - Ok(( - int!(0), - MilliSecondsSinceUnixEpoch( - eventid_info - .get(event_id) - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), - ), - )) - }) - .map_err(|_| "Error sorting prev events".to_owned())?; - - let mut errors = 0; - for prev_id in dbg!(sorted) { - match db.rooms.is_disabled(room_id) { - Ok(false) => {} - _ => { - return Err( - "Federation of this room is currently disabled on this server.".to_owned(), - ); - } - } - - if let Some((time, tries)) = db - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(&*prev_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", prev_id); - continue; - } - } - - if errors >= 5 { - break; - } - if let Some((pdu, json)) = eventid_info.remove(&*prev_id) { - if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - continue; - } - - let start_time = Instant::now(); - db.globals - .roomid_federationhandletime - .write() - .unwrap() - .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - if let Err(e) = upgrade_outlier_to_timeline_pdu( - pdu, - json, - &create_event, - origin, - db, - room_id, - pub_key_map, - ) - .await - { - errors += 1; - warn!("Prev event {} failed: {}", prev_id, e); - match db - .globals - .bad_event_ratelimiter - .write() - .unwrap() - .entry((*prev_id).to_owned()) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1 + 1) - } - } - } - let elapsed = start_time.elapsed(); - db.globals - .roomid_federationhandletime - .write() - .unwrap() - .remove(&room_id.to_owned()); - warn!( - "Handling prev event {} took {}m{}s", - prev_id, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - } - } - - let start_time = Instant::now(); - db.globals - .roomid_federationhandletime - .write() - .unwrap() - .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); - let r = upgrade_outlier_to_timeline_pdu( - incoming_pdu, - val, - &create_event, - origin, - db, - room_id, - pub_key_map, - ) - .await; - db.globals - .roomid_federationhandletime - .write() - .unwrap() - .remove(&room_id.to_owned()); - - r -} - -#[tracing::instrument(skip(create_event, value, db, pub_key_map))] -fn handle_outlier_pdu<'a>( - origin: &'a ServerName, - create_event: &'a PduEvent, - event_id: &'a EventId, - room_id: &'a RoomId, - value: BTreeMap<String, CanonicalJsonValue>, - db: &'a Database, - pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, -) -> AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>), String>> { - Box::pin(async move { - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - - // We go through all the signatures we see on the value and fetch the corresponding signing - // keys - fetch_required_signing_keys(&value, pub_key_map, db) - .await - .map_err(|e| e.to_string())?; - - // 2. Check signatures, otherwise drop - // 3. check content hash, redact if doesn't match - - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - "Invalid create event in db.".to_owned() - })?; - - let room_version_id = &create_event_content.room_version; - let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); - - let mut val = match ruma::signatures::verify_event( - &*pub_key_map.read().map_err(|_| "RwLock is poisoned.")?, - &value, - room_version_id, - ) { - Err(e) => { - // Drop - warn!("Dropping bad event {}: {}", event_id, e); - return Err("Signature verification failed".to_owned()); - } - Ok(ruma::signatures::Verified::Signatures) => { - // Redact - warn!("Calculated hash does not match: {}", event_id); - match ruma::signatures::redact(&value, room_version_id) { - Ok(obj) => obj, - Err(_) => return Err("Redaction failed".to_owned()), - } - } - Ok(ruma::signatures::Verified::All) => value, - }; - - // Now that we have checked the signature and hashes we can add the eventID and convert - // to our PduEvent type - val.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - let incoming_pdu = serde_json::from_value::<PduEvent>( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| "Event is not a valid PDU.".to_owned())?; - - // 4. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" - // EDIT: Step 5 is not applied anymore because it failed too often - warn!("Fetching auth events for {}", incoming_pdu.event_id); - fetch_and_handle_outliers( - db, - origin, - &incoming_pdu - .auth_events - .iter() - .map(|x| Arc::from(&**x)) - .collect::<Vec<_>>(), - create_event, - room_id, - pub_key_map, - ) - .await; - - // 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events - info!( - "Auth check for {} based on auth events", - incoming_pdu.event_id - ); - - // Build map of auth events - let mut auth_events = HashMap::new(); - for id in &incoming_pdu.auth_events { - let auth_event = match db.rooms.get_pdu(id).map_err(|e| e.to_string())? { - Some(e) => e, - None => { - warn!("Could not find auth event {}", id); - continue; - } - }; - - match auth_events.entry(( - auth_event.kind.to_string().into(), - auth_event - .state_key - .clone() - .expect("all auth events have state keys"), - )) { - hash_map::Entry::Vacant(v) => { - v.insert(auth_event); - } - hash_map::Entry::Occupied(_) => { - return Err( - "Auth event's type and state_key combination exists multiple times." - .to_owned(), - ) - } - } - } - - // The original create event must be in the auth events - if auth_events - .get(&(StateEventType::RoomCreate, "".to_owned())) - .map(|a| a.as_ref()) - != Some(create_event) - { - return Err("Incoming event refers to wrong create event.".to_owned()); - } - - if !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::<PduEvent>, // TODO: third party invite - |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), - ) - .map_err(|_e| "Auth check failed".to_owned())? - { - return Err("Event has failed auth check with auth events.".to_owned()); - } - - info!("Validation successful."); - - // 7. Persist the event as an outlier. - db.rooms - .add_pdu_outlier(&incoming_pdu.event_id, &val) - .map_err(|_| "Failed to add pdu as outlier.".to_owned())?; - info!("Added pdu as outlier."); - - Ok((Arc::new(incoming_pdu), val)) - }) -} - -#[tracing::instrument(skip(incoming_pdu, val, create_event, db, pub_key_map))] -async fn upgrade_outlier_to_timeline_pdu( - incoming_pdu: Arc<PduEvent>, - val: BTreeMap<String, CanonicalJsonValue>, - create_event: &PduEvent, - origin: &ServerName, - db: &Database, - room_id: &RoomId, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, -) -> Result<Option<Vec<u8>>, String> { - if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { - return Ok(Some(pduid)); - } - - if db - .rooms - .is_event_soft_failed(&incoming_pdu.event_id) - .map_err(|_| "Failed to ask db for soft fail".to_owned())? - { - return Err("Event has been soft failed".into()); - } - - info!("Upgrading {} to timeline pdu", incoming_pdu.event_id); - - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - "Invalid create event in db.".to_owned() - })?; - - let room_version_id = &create_event_content.room_version; - let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); - - // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities - // doing all the checks in this list starting at 1. These are not timeline events. - - // TODO: if we know the prev_events of the incoming event we can avoid the request and build - // the state from a known point and resolve if > 1 prev_event - - info!("Requesting state at event"); - let mut state_at_incoming_event = None; - - if incoming_pdu.prev_events.len() == 1 { - let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = db - .rooms - .pdu_shortstatehash(prev_event) - .map_err(|_| "Failed talking to db".to_owned())?; - - let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some(db.rooms.state_full_ids(shortstatehash).await) - } else { - None - }; - - if let Some(Ok(mut state)) = state { - info!("Using cached state"); - let prev_pdu = - db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { - "Could not find prev event, but we know the state.".to_owned() - })?; - - if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &prev_pdu.kind.to_string().into(), - state_key, - &db.globals, - ) - .map_err(|_| "Failed to create shortstatekey.".to_owned())?; - - state.insert(shortstatekey, Arc::from(prev_event)); - // Now it's the state after the pdu - } - - state_at_incoming_event = Some(state); - } - } else { - info!("Calculating state at event using state res"); - let mut extremity_sstatehashes = HashMap::new(); - - let mut okay = true; - for prev_eventid in &incoming_pdu.prev_events { - let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu(prev_eventid) { - pdu - } else { - okay = false; - break; - }; - - let sstatehash = if let Ok(Some(s)) = db.rooms.pdu_shortstatehash(prev_eventid) { - s - } else { - okay = false; - break; - }; - - extremity_sstatehashes.insert(sstatehash, prev_event); - } - - if okay { - let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); - let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - - for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: BTreeMap<_, _> = db - .rooms - .state_full_ids(sstatehash) - .await - .map_err(|_| "Failed to ask db for room state.".to_owned())?; - - if let Some(state_key) = &prev_event.state_key { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &prev_event.kind.to_string().into(), - state_key, - &db.globals, - ) - .map_err(|_| "Failed to create shortstatekey.".to_owned())?; - leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); - // Now it's the state after the pdu - } - - let mut state = StateMap::with_capacity(leaf_state.len()); - let mut starting_events = Vec::with_capacity(leaf_state.len()); - - for (k, id) in leaf_state { - if let Ok((ty, st_key)) = db.rooms.get_statekey_from_short(k) { - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType - state.insert((ty.to_string().into(), st_key), id.clone()); - } else { - warn!("Failed to get_statekey_from_short."); - } - starting_events.push(id); - } - - auth_chain_sets.push( - get_auth_chain(room_id, starting_events, db) - .await - .map_err(|_| "Failed to load auth chain.".to_owned())? - .collect(), - ); - - fork_states.push(state); - } - - let lock = db.globals.stateres_mutex.lock(); - - let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = db.rooms.get_pdu(id); - if let Err(e) = &res { - error!("LOOK AT ME Failed to fetch event: {}", e); - } - res.ok().flatten() - }); - drop(lock); - - state_at_incoming_event = match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &event_type.to_string().into(), - &state_key, - &db.globals, - ) - .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; - Ok((shortstatekey, event_id)) - }) - .collect::<Result<_, String>>()?, - ), - Err(e) => { - warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); - None - } - } - } - } - - if state_at_incoming_event.is_none() { - info!("Calling /state_ids"); - // Call /state_ids to find out what the state at this pdu is. We trust the server's - // response to some extend, but we still do a lot of checks on the events - match db - .sending - .send_federation_request( - &db.globals, - origin, - get_room_state_ids::v1::Request { - room_id, - event_id: &incoming_pdu.event_id, - }, - ) - .await - { - Ok(res) => { - info!("Fetching state events at event."); - let state_vec = fetch_and_handle_outliers( - db, - origin, - &res.pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::<Vec<_>>(), - create_event, - room_id, - pub_key_map, - ) - .await; - - let mut state: BTreeMap<_, Arc<EventId>> = BTreeMap::new(); - for (pdu, _) in state_vec { - let state_key = pdu - .state_key - .clone() - .ok_or_else(|| "Found non-state pdu in state events.".to_owned())?; - - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &pdu.kind.to_string().into(), - &state_key, - &db.globals, - ) - .map_err(|_| "Failed to create shortstatekey.".to_owned())?; - - match state.entry(shortstatekey) { - btree_map::Entry::Vacant(v) => { - v.insert(Arc::from(&*pdu.event_id)); - } - btree_map::Entry::Occupied(_) => return Err( - "State event's type and state_key combination exists multiple times." - .to_owned(), - ), - } - } - - // The original create event must still be in the state - let create_shortstatekey = db - .rooms - .get_shortstatekey(&StateEventType::RoomCreate, "") - .map_err(|_| "Failed to talk to db.")? - .expect("Room exists"); - - if state.get(&create_shortstatekey).map(|id| id.as_ref()) - != Some(&create_event.event_id) - { - return Err("Incoming event refers to wrong create event.".to_owned()); - } - - state_at_incoming_event = Some(state); - } - Err(e) => { - warn!("Fetching state for event failed: {}", e); - return Err("Fetching state for event failed".into()); - } - }; - } - - let state_at_incoming_event = - state_at_incoming_event.expect("we always set this to some above"); - - info!("Starting auth check"); - // 11. Check the auth of the event passes based on the state of the event - let check_result = state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::<PduEvent>, // TODO: third party invite - |k, s| { - db.rooms - .get_shortstatekey(&k.to_string().into(), s) - .ok() - .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten()) - }, - ) - .map_err(|_e| "Auth check failed.".to_owned())?; - - if !check_result { - return Err("Event has failed auth check with state at the event.".into()); - } - info!("Auth check succeeded"); - - // We start looking at current room state now, so lets lock the room - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Now we calculate the set of extremities this room has after the incoming event has been - // applied. We start with the previous extremities (aka leaves) - info!("Calculating extremities"); - let mut extremities = db - .rooms - .get_pdu_leaves(room_id) - .map_err(|_| "Failed to load room leaves".to_owned())?; - - // Remove any forward extremities that are referenced by this incoming event's prev_events - for prev_event in &incoming_pdu.prev_events { - if extremities.contains(prev_event) { - extremities.remove(prev_event); - } - } - - // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(db.rooms.is_event_referenced(room_id, id), Ok(true))); - - info!("Compressing state at event"); - let state_ids_compressed = state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - db.rooms - .compress_state_event(*shortstatekey, id, &db.globals) - .map_err(|_| "Failed to compress_state_event".to_owned()) - }) - .collect::<Result<_, _>>()?; - - // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it - info!("Starting soft fail auth check"); - - let auth_events = db - .rooms - .get_auth_events( - room_id, - &incoming_pdu.kind, - &incoming_pdu.sender, - incoming_pdu.state_key.as_deref(), - &incoming_pdu.content, - ) - .map_err(|_| "Failed to get_auth_events.".to_owned())?; - - let soft_fail = !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::<PduEvent>, - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - .map_err(|_e| "Auth check failed.".to_owned())?; - - if soft_fail { - append_incoming_pdu( - db, - &incoming_pdu, - val, - extremities.iter().map(Deref::deref), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .map_err(|e| { - warn!("Failed to add pdu to db: {}", e); - "Failed to add pdu to db.".to_owned() - })?; - - // Soft fail, we keep the event as an outlier but don't add it to the timeline - warn!("Event was soft failed: {:?}", incoming_pdu); - db.rooms - .mark_event_soft_failed(&incoming_pdu.event_id) - .map_err(|_| "Failed to set soft failed flag".to_owned())?; - return Err("Event has been soft failed".into()); - } - - if incoming_pdu.state_key.is_some() { - info!("Loading current room state ids"); - let current_sstatehash = db - .rooms - .current_shortstatehash(room_id) - .map_err(|_| "Failed to load current state hash.".to_owned())? - .expect("every room has state"); - - let current_state_ids = db - .rooms - .state_full_ids(current_sstatehash) - .await - .map_err(|_| "Failed to load room state.")?; - - info!("Preparing for stateres to derive new room state"); - let mut extremity_sstatehashes = HashMap::new(); - - info!("Loading extremities"); - for id in dbg!(&extremities) { - match db - .rooms - .get_pdu(id) - .map_err(|_| "Failed to ask db for pdu.".to_owned())? - { - Some(leaf_pdu) => { - extremity_sstatehashes.insert( - db.rooms - .pdu_shortstatehash(&leaf_pdu.event_id) - .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? - .ok_or_else(|| { - error!( - "Found extremity pdu with no statehash in db: {:?}", - leaf_pdu - ); - "Found pdu with no statehash in db.".to_owned() - })?, - leaf_pdu, - ); - } - _ => { - error!("Missing state snapshot for {:?}", id); - return Err("Missing state snapshot.".to_owned()); - } - } - } - - let mut fork_states = Vec::new(); - - // 12. Ensure that the state is derived from the previous current state (i.e. we calculated - // by doing state res where one of the inputs was a previously trusted set of state, - // don't just trust a set of state we got from a remote). - - // We do this by adding the current state to the list of fork states - extremity_sstatehashes.remove(¤t_sstatehash); - fork_states.push(current_state_ids); - - // We also add state after incoming event to the fork states - let mut state_after = state_at_incoming_event.clone(); - if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &incoming_pdu.kind.to_string().into(), - state_key, - &db.globals, - ) - .map_err(|_| "Failed to create shortstatekey.".to_owned())?; - - state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); - } - fork_states.push(state_after); - - let mut update_state = false; - // 14. Use state resolution to find new room state - let new_room_state = if fork_states.is_empty() { - return Err("State is empty.".to_owned()); - } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { - info!("State resolution trivial"); - // There was only one state, so it has to be the room's current state (because that is - // always included) - fork_states[0] - .iter() - .map(|(k, id)| { - db.rooms - .compress_state_event(*k, id, &db.globals) - .map_err(|_| "Failed to compress_state_event.".to_owned()) - }) - .collect::<Result<_, _>>()? - } else { - info!("Loading auth chains"); - // We do need to force an update to this room's state - update_state = true; - - let mut auth_chain_sets = Vec::new(); - for state in &fork_states { - auth_chain_sets.push( - get_auth_chain( - room_id, - state.iter().map(|(_, id)| id.clone()).collect(), - db, - ) - .await - .map_err(|_| "Failed to load auth chain.".to_owned())? - .collect(), - ); - } - - info!("Loading fork states"); - - let fork_states: Vec<_> = fork_states - .into_iter() - .map(|map| { - map.into_iter() - .filter_map(|(k, id)| { - db.rooms - .get_statekey_from_short(k) - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType - .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) - .map_err(|e| warn!("Failed to get_statekey_from_short: {}", e)) - .ok() - }) - .collect::<StateMap<_>>() - }) - .collect(); - - info!("Resolving state"); - - let lock = db.globals.stateres_mutex.lock(); - let state = match state_res::resolve( - room_version_id, - &fork_states, - auth_chain_sets, - |id| { - let res = db.rooms.get_pdu(id); - if let Err(e) = &res { - error!("LOOK AT ME Failed to fetch event: {}", e); - } - res.ok().flatten() - }, - ) { - Ok(new_state) => new_state, - Err(_) => { - return Err("State resolution failed, either an event could not be found or deserialization".into()); - } - }; - - drop(lock); - - info!("State resolution done. Compressing state"); - - state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = db - .rooms - .get_or_create_shortstatekey( - &event_type.to_string().into(), - &state_key, - &db.globals, - ) - .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; - db.rooms - .compress_state_event(shortstatekey, &event_id, &db.globals) - .map_err(|_| "Failed to compress state event".to_owned()) - }) - .collect::<Result<_, _>>()? - }; - - // Set the new room state to the resolved state - if update_state { - info!("Forcing new room state"); - db.rooms - .force_state(room_id, new_room_state, db) - .map_err(|_| "Failed to set new room state.".to_owned())?; - } - } - - info!("Appending pdu to timeline"); - extremities.insert(incoming_pdu.event_id.clone()); - - // Now that the event has passed all auth it is added into the timeline. - // We use the `state_at_event` instead of `state_after` so we accurately - // represent the state for this event. - - let pdu_id = append_incoming_pdu( - db, - &incoming_pdu, - val, - extremities.iter().map(Deref::deref), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .map_err(|e| { - warn!("Failed to add pdu to db: {}", e); - "Failed to add pdu to db.".to_owned() - })?; - - info!("Appended incoming pdu"); - - // Event has passed all auth/stateres checks - drop(state_lock); - Ok(pdu_id) -} - -/// Find the event and auth it. Once the event is validated (steps 1 - 8) -/// it is appended to the outliers Tree. -/// -/// Returns pdu and if we fetched it over federation the raw json. -/// -/// a. Look in the main timeline (pduid_pdu tree) -/// b. Look at outlier pdu tree -/// c. Ask origin server over federation -/// d. TODO: Ask other servers over federation? -#[tracing::instrument(skip_all)] -pub(crate) fn fetch_and_handle_outliers<'a>( - db: &'a Database, - origin: &'a ServerName, - events: &'a [Arc<EventId>], - create_event: &'a PduEvent, - room_id: &'a RoomId, - pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, -) -> AsyncRecursiveType<'a, Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>> { - Box::pin(async move { - let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; - - let mut pdus = vec![]; - for id in events { - if let Some((time, tries)) = db.globals.bad_event_ratelimiter.read().unwrap().get(&**id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", id); - continue; - } - } - - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = db.rooms.get_pdu(id) { - trace!("Found {} in db", id); - pdus.push((local_pdu, None)); - continue; - } - - // c. Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. - let mut todo_auth_events = vec![Arc::clone(id)]; - let mut events_in_reverse_order = Vec::new(); - let mut events_all = HashSet::new(); - let mut i = 0; - while let Some(next_id) = todo_auth_events.pop() { - if events_all.contains(&next_id) { - continue; - } - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - - if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) { - trace!("Found {} in db", id); - continue; - } - - info!("Fetching {} over federation.", next_id); - match db - .sending - .send_federation_request( - &db.globals, - origin, - get_event::v1::Request { event_id: &next_id }, - ) - .await - { - Ok(res) => { - info!("Got {} over federation", next_id); - let (calculated_event_id, value) = - match crate::pdu::gen_event_id_canonical_json(&res.pdu, &db) { - Ok(t) => t, - Err(_) => { - back_off((*next_id).to_owned()); - continue; - } - }; - - if calculated_event_id != *next_id { - warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", - next_id, calculated_event_id, &res.pdu); - } - - if let Some(auth_events) = - value.get("auth_events").and_then(|c| c.as_array()) - { - for auth_event in auth_events { - if let Ok(auth_event) = - serde_json::from_value(auth_event.clone().into()) - { - let a: Arc<EventId> = auth_event; - todo_auth_events.push(a); - } else { - warn!("Auth event id is not valid"); - } - } - } else { - warn!("Auth event list invalid"); - } - - events_in_reverse_order.push((next_id.clone(), value)); - events_all.insert(next_id); - } - Err(_) => { - warn!("Failed to fetch event: {}", next_id); - back_off((*next_id).to_owned()); - } - } - } - - for (next_id, value) in events_in_reverse_order.iter().rev() { - match handle_outlier_pdu( - origin, - create_event, - next_id, - room_id, - value.clone(), - db, - pub_key_map, - ) - .await - { - Ok((pdu, json)) => { - if next_id == id { - pdus.push((pdu, Some(json))); - } - } - Err(e) => { - warn!("Authentication of event {} failed: {:?}", next_id, e); - back_off((**next_id).to_owned()); - } - } - } - } - pdus - }) -} - -/// Search the DB for the signing keys of the given server, if we don't have them -/// fetch them from the server and save to our DB. -#[tracing::instrument(skip_all)] -pub(crate) async fn fetch_signing_keys( - db: &Database, - origin: &ServerName, - signature_ids: Vec<String>, -) -> Result<BTreeMap<String, Base64>> { - let contains_all_ids = - |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); - - let permit = db - .globals - .servername_ratelimiter - .read() - .unwrap() - .get(origin) - .map(|s| Arc::clone(s).acquire_owned()); - - let permit = match permit { - Some(p) => p, - None => { - let mut write = db.globals.servername_ratelimiter.write().unwrap(); - let s = Arc::clone( - write - .entry(origin.to_owned()) - .or_insert_with(|| Arc::new(Semaphore::new(1))), - ); - - s.acquire_owned() - } - } - .await; - - let back_off = |id| match db - .globals - .bad_signature_ratelimiter - .write() - .unwrap() - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; - - if let Some((time, tries)) = db - .globals - .bad_signature_ratelimiter - .read() - .unwrap() - .get(&signature_ids) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {:?}", signature_ids); - return Err(Error::BadServerResponse("bad signature, still backing off")); - } - } - - trace!("Loading signing keys for {}", origin); - - let mut result: BTreeMap<_, _> = db - .globals - .signing_keys_for(origin)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if contains_all_ids(&result) { - return Ok(result); - } - - debug!("Fetching signing keys for {} over federation", origin); - - if let Some(server_key) = db - .sending - .send_federation_request(&db.globals, origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - db.globals.add_signing_key(origin, server_key.clone())?; - - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - - for server in db.globals.trusted_servers() { - debug!("Asking {} for {}'s signing key", server, origin); - if let Some(server_keys) = db - .sending - .send_federation_request( - &db.globals, - server, - get_remote_server_keys::v2::Request::new( - origin, - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime to large"), - ) - .expect("time is valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::<Vec<_>>() - }) - { - trace!("Got signing keys: {:?}", server_keys); - for k in server_keys { - db.globals.add_signing_key(origin, k.clone())?; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } - - if contains_all_ids(&result) { - return Ok(result); - } - } - } - - drop(permit); - - back_off(signature_ids); - - warn!("Failed to find public key for server: {}", origin); - Err(Error::BadServerResponse( - "Failed to find public key for server", - )) -} - -/// Append the incoming event setting the state snapshot to the state from the -/// server that sent the event. -#[tracing::instrument(skip_all)] -fn append_incoming_pdu<'a>( - db: &Database, - pdu: &PduEvent, - pdu_json: CanonicalJsonObject, - new_room_leaves: impl IntoIterator<Item = &'a EventId> + Clone + Debug, - state_ids_compressed: HashSet<CompressedStateEvent>, - soft_fail: bool, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room mutex -) -> Result<Option<Vec<u8>>> { - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - db.rooms.set_event_state( - &pdu.event_id, - &pdu.room_id, - state_ids_compressed, - &db.globals, - )?; - - if soft_fail { - db.rooms - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - db.rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; - return Ok(None); - } - - let pdu_id = db.rooms.append_pdu(pdu, pdu_json, new_room_leaves, db)?; - - for appservice in db.appservice.all()? { - if db.rooms.appservice_in_room(&pdu.room_id, &appservice, db)? { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - continue; - } - - if let Some(namespaces) = appservice.1.get("namespaces") { - let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::<Vec<_>>() - }); - let aliases = namespaces - .get("aliases") - .and_then(|aliases| aliases.as_sequence()) - .map_or_else(Vec::new, |aliases| { - aliases - .iter() - .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) - .collect::<Vec<_>>() - }); - let rooms = namespaces - .get("rooms") - .and_then(|rooms| rooms.as_sequence()); - - let matching_users = |users: &Regex| { - users.is_match(pdu.sender.as_str()) - || pdu.kind == RoomEventType::RoomMember - && pdu - .state_key - .as_ref() - .map_or(false, |state_key| users.is_match(state_key)) - }; - let matching_aliases = |aliases: &Regex| { - db.rooms - .room_aliases(&pdu.room_id) - .filter_map(|r| r.ok()) - .any(|room_alias| aliases.is_match(room_alias.as_str())) - }; - - if aliases.iter().any(matching_aliases) - || rooms.map_or(false, |rooms| rooms.contains(&pdu.room_id.as_str().into())) - || users.iter().any(matching_users) - { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - } - } - } - - Ok(Some(pdu_id)) -} - -#[tracing::instrument(skip(starting_events, db))] -pub(crate) async fn get_auth_chain<'a>( - room_id: &RoomId, - starting_events: Vec<Arc<EventId>>, - db: &'a Database, -) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> { - const NUM_BUCKETS: usize = 50; - - let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; - - let mut i = 0; - for id in starting_events { - let short = db.rooms.get_or_create_shorteventid(&id, &db.globals)?; - let bucket_id = (short % NUM_BUCKETS as u64) as usize; - buckets[bucket_id].insert((short, id.clone())); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - - let mut full_auth_chain = HashSet::new(); - - let mut hits = 0; - let mut misses = 0; - for chunk in buckets { - if chunk.is_empty() { - continue; - } - - let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = db.rooms.get_auth_chain_from_cache(&chunk_key)? { - hits += 1; - full_auth_chain.extend(cached.iter().copied()); - continue; - } - misses += 1; - - let mut chunk_cache = HashSet::new(); - let mut hits2 = 0; - let mut misses2 = 0; - let mut i = 0; - for (sevent_id, event_id) in chunk { - if let Some(cached) = db.rooms.get_auth_chain_from_cache(&[sevent_id])? { - hits2 += 1; - chunk_cache.extend(cached.iter().copied()); - } else { - misses2 += 1; - let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db)?); - db.rooms - .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; - println!( - "cache missed event {} with auth chain len {}", - event_id, - auth_chain.len() - ); - chunk_cache.extend(auth_chain.iter()); - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - }; - } - println!( - "chunk missed with len {}, event hits2: {}, misses2: {}", - chunk_cache.len(), - hits2, - misses2 - ); - let chunk_cache = Arc::new(chunk_cache); - db.rooms - .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; - full_auth_chain.extend(chunk_cache.iter()); - } - - println!( - "total: {}, chunk hits: {}, misses: {}", - full_auth_chain.len(), - hits, - misses - ); - - Ok(full_auth_chain - .into_iter() - .filter_map(move |sid| db.rooms.get_eventid_from_short(sid).ok())) -} - -#[tracing::instrument(skip(event_id, db))] -fn get_auth_chain_inner( - room_id: &RoomId, - event_id: &EventId, - db: &Database, -) -> Result<HashSet<u64>> { - let mut todo = vec![Arc::from(event_id)]; - let mut found = HashSet::new(); - - while let Some(event_id) = todo.pop() { - match db.rooms.get_pdu(&event_id) { - Ok(Some(pdu)) => { - if pdu.room_id != room_id { - return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); - } - for auth_event in &pdu.auth_events { - let sauthevent = db - .rooms - .get_or_create_shorteventid(auth_event, &db.globals)?; - - if !found.contains(&sauthevent) { - found.insert(sauthevent); - todo.push(auth_event.clone()); - } - } - } - Ok(None) => { - warn!("Could not find pdu mentioned in auth events: {}", event_id); - } - Err(e) => { - warn!("Could not load event in auth chain: {} {}", event_id, e); - } - } - } - - Ok(found) -} - -/// # `GET /_matrix/federation/v1/event/{eventId}` -/// -/// Retrieves a single event from the server. -/// -/// - Only works if a user of this server is currently invited or joined the room -pub async fn get_event_route( - db: DatabaseGuard, - body: Ruma<get_event::v1::IncomingRequest>, -) -> Result<get_event::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - let event = db - .rooms - .get_pdu_json(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; - - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - - if !db.rooms.server_in_room(sender_servername, room_id)? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room", - )); - } - - Ok(get_event::v1::Response { - origin: db.globals.server_name().to_owned(), - origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: PduEvent::convert_to_outgoing_federation_event(event), - }) -} - -/// # `POST /_matrix/federation/v1/get_missing_events/{roomId}` -/// -/// Retrieves events that the sender is missing. -pub async fn get_missing_events_route( - db: DatabaseGuard, - body: Ruma<get_missing_events::v1::IncomingRequest>, -) -> Result<get_missing_events::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room", - )); - } - - acl_check(sender_servername, &body.room_id, &db)?; - - let mut queued_events = body.latest_events.clone(); - let mut events = Vec::new(); - - let mut i = 0; - while i < queued_events.len() && events.len() < u64::from(body.limit) as usize { - if let Some(pdu) = db.rooms.get_pdu_json(&queued_events[i])? { - let room_id_str = pdu - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let event_room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - - if event_room_id != body.room_id { - warn!( - "Evil event detected: Event {} found while searching in room {}", - queued_events[i], body.room_id - ); - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Evil event detected", - )); - } - - if body.earliest_events.contains(&queued_events[i]) { - i += 1; - continue; - } - queued_events.extend_from_slice( - &serde_json::from_value::<Vec<Box<EventId>>>( - serde_json::to_value(pdu.get("prev_events").cloned().ok_or_else(|| { - Error::bad_database("Event in db has no prev_events field.") - })?) - .expect("canonical json is valid json value"), - ) - .map_err(|_| Error::bad_database("Invalid prev_events content in pdu in db."))?, - ); - events.push(PduEvent::convert_to_outgoing_federation_event(pdu)); - } - i += 1; - } - - Ok(get_missing_events::v1::Response { events }) -} - -/// # `GET /_matrix/federation/v1/event_auth/{roomId}/{eventId}` -/// -/// Retrieves the auth chain for a given event. -/// -/// - This does not include the event itself -pub async fn get_event_authorization_route( - db: DatabaseGuard, - body: Ruma<get_event_authorization::v1::IncomingRequest>, -) -> Result<get_event_authorization::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room.", - )); - } - - acl_check(sender_servername, &body.room_id, &db)?; - - let event = db - .rooms - .get_pdu_json(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; - - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - - let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db).await?; - - Ok(get_event_authorization::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| db.rooms.get_pdu_json(&id).ok()?) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - }) -} - -/// # `GET /_matrix/federation/v1/state/{roomId}` -/// -/// Retrieves the current state of the room. -pub async fn get_room_state_route( - db: DatabaseGuard, - body: Ruma<get_room_state::v1::IncomingRequest>, -) -> Result<get_room_state::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room.", - )); - } - - acl_check(sender_servername, &body.room_id, &db)?; - - let shortstatehash = db - .rooms - .pdu_shortstatehash(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Pdu state not found.", - ))?; - - let pdus = db - .rooms - .state_full_ids(shortstatehash) - .await? - .into_iter() - .map(|(_, id)| { - PduEvent::convert_to_outgoing_federation_event( - db.rooms.get_pdu_json(&id).unwrap().unwrap(), - ) - }) - .collect(); - - let auth_chain_ids = - get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; - - Ok(get_room_state::v1::Response { - auth_chain: auth_chain_ids - .map(|id| { - db.rooms.get_pdu_json(&id).map(|maybe_json| { - PduEvent::convert_to_outgoing_federation_event(maybe_json.unwrap()) - }) - }) - .filter_map(|r| r.ok()) - .collect(), - pdus, - }) -} - -/// # `GET /_matrix/federation/v1/state_ids/{roomId}` -/// -/// Retrieves the current state of the room. -pub async fn get_room_state_ids_route( - db: DatabaseGuard, - body: Ruma<get_room_state_ids::v1::IncomingRequest>, -) -> Result<get_room_state_ids::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room.", - )); - } - - acl_check(sender_servername, &body.room_id, &db)?; - - let shortstatehash = db - .rooms - .pdu_shortstatehash(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Pdu state not found.", - ))?; - - let pdu_ids = db - .rooms - .state_full_ids(shortstatehash) - .await? - .into_iter() - .map(|(_, id)| (*id).to_owned()) - .collect(); - - let auth_chain_ids = - get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; - - Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), - pdu_ids, - }) -} - -/// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}` -/// -/// Creates a join template. -pub async fn create_join_event_template_route( - db: DatabaseGuard, - body: Ruma<prepare_join_event::v1::IncomingRequest>, -) -> Result<prepare_join_event::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - if !db.rooms.exists(&body.room_id)? { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room is unknown to this server.", - )); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - acl_check(sender_servername, &body.room_id, &db)?; - - // TODO: Conduit does not implement restricted join rules yet, we always reject - let join_rules_event = - db.rooms - .room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?; - - let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; - - if let Some(join_rules_event_content) = join_rules_event_content { - if matches!( - join_rules_event_content.join_rule, - JoinRule::Restricted { .. } - ) { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Conduit does not support restricted rooms yet.", - )); - } - } - - let prev_events: Vec<_> = db - .rooms - .get_pdu_leaves(&body.room_id)? - .into_iter() - .take(20) - .collect(); - - let create_event = db - .rooms - .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option<RoomCreateEventContent> = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - - // If there was no create event yet, assume we are creating a room with the default version - // right now - let room_version_id = create_event_content - .map_or(db.globals.default_room_version(), |create_event| { - create_event.room_version - }); - let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - - if !body.ver.contains(&room_version_id) { - return Err(Error::BadRequest( - ErrorKind::IncompatibleRoomVersion { - room_version: room_version_id, - }, - "Room version not supported.", - )); - } - - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: None, - blurhash: None, - displayname: None, - is_direct: None, - membership: MembershipState::Join, - third_party_invite: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("member event is valid value"); - - let state_key = body.user_id.to_string(); - let kind = StateEventType::RoomMember; - - let auth_events = db.rooms.get_auth_events( - &body.room_id, - &kind.to_string().into(), - &body.user_id, - Some(&state_key), - &content, - )?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) - + uint!(1); - - let mut unsigned = BTreeMap::new(); - - if let Some(prev_pdu) = db.rooms.room_state_get(&body.room_id, &kind, &state_key)? { - unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); - unsigned.insert( - "prev_sender".to_owned(), - to_raw_value(&prev_pdu.sender).expect("UserId is valid"), - ); - } - - let pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater").into(), - room_id: body.room_id.clone(), - sender: body.user_id.clone(), - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"), - kind: kind.to_string().into(), - content, - state_key: Some(state_key), - prev_events, - depth, - auth_events: auth_events - .iter() - .map(|(_, pdu)| pdu.event_id.clone()) - .collect(), - redacts: None, - unsigned: if unsigned.is_empty() { - None - } else { - Some(to_raw_value(&unsigned).expect("to_raw_value always works")) - }, - hashes: EventHash { - sha256: "aaa".to_owned(), - }, - signatures: None, - }; - - let auth_check = state_res::auth_check( - &room_version, - &pdu, - None::<PduEvent>, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - .map_err(|e| { - error!("{:?}", e); - Error::bad_database("Auth check failed.") - })?; - - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Event is not authorized.", - )); - } - - // Hash and sign - let mut pdu_json = - utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); - - pdu_json.remove("event_id"); - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - CanonicalJsonValue::String(db.globals.server_name().as_str().to_owned()), - ); - - Ok(prepare_join_event::v1::Response { - room_version: Some(room_version_id), - event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), - }) -} - -async fn create_join_event( - db: &DatabaseGuard, - sender_servername: &ServerName, - room_id: &RoomId, - pdu: &RawJsonValue, -) -> Result<RoomState> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - if !db.rooms.exists(room_id)? { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room is unknown to this server.", - )); - } - - acl_check(sender_servername, room_id, db)?; - - // TODO: Conduit does not implement restricted join rules yet, we always reject - let join_rules_event = db - .rooms - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; - - if let Some(join_rules_event_content) = join_rules_event_content { - if matches!( - join_rules_event_content.join_rule, - JoinRule::Restricted { .. } - ) { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Conduit does not support restricted rooms yet.", - )); - } - } - - // We need to return the state prior to joining, let's keep a reference to that here - let shortstatehash = db - .rooms - .current_shortstatehash(room_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Pdu state not found.", - ))?; - - let pub_key_map = RwLock::new(BTreeMap::new()); - // let mut auth_cache = EventMap::new(); - - // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(pdu, &db) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; - - let origin: Box<ServerName> = serde_json::from_value( - serde_json::to_value(value.get("origin").ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event needs an origin field.", - ))?) - .expect("CanonicalJson is valid json value"), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - - let mutex = Arc::clone( - db.globals - .roomid_mutex_federation - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let mutex_lock = mutex.lock().await; - let pdu_id = handle_incoming_pdu(&origin, &event_id, room_id, value, true, db, &pub_key_map) - .await - .map_err(|e| { - warn!("Error while handling incoming send join PDU: {}", e); - Error::BadRequest( - ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; - drop(mutex_lock); - - let state_ids = db.rooms.state_full_ids(shortstatehash).await?; - let auth_chain_ids = get_auth_chain( - room_id, - state_ids.iter().map(|(_, id)| id.clone()).collect(), - db, - ) - .await?; - - let servers = db - .rooms - .room_servers(room_id) - .filter_map(|r| r.ok()) - .filter(|server| &**server != db.globals.server_name()); - - db.sending.send_pdu(servers, &pdu_id)?; - - db.flush()?; - - Ok(RoomState { - auth_chain: auth_chain_ids - .filter_map(|id| db.rooms.get_pdu_json(&id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - state: state_ids - .iter() - .filter_map(|(_, id)| db.rooms.get_pdu_json(id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - }) -} - -/// # `PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}` -/// -/// Submits a signed join event. -pub async fn create_join_event_v1_route( - db: DatabaseGuard, - body: Ruma<create_join_event::v1::IncomingRequest>, -) -> Result<create_join_event::v1::Response> { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - let room_state = create_join_event(&db, sender_servername, &body.room_id, &body.pdu).await?; - - Ok(create_join_event::v1::Response { room_state }) -} - -/// # `PUT /_matrix/federation/v2/send_join/{roomId}/{eventId}` -/// -/// Submits a signed join event. -pub async fn create_join_event_v2_route( - db: DatabaseGuard, - body: Ruma<create_join_event::v2::IncomingRequest>, -) -> Result<create_join_event::v2::Response> { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - let room_state = create_join_event(&db, sender_servername, &body.room_id, &body.pdu).await?; - - Ok(create_join_event::v2::Response { room_state }) -} - -/// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}` -/// -/// Invites a remote user to a room. -pub async fn create_invite_route( - db: DatabaseGuard, - body: Ruma<create_invite::v2::IncomingRequest>, -) -> Result<create_invite::v2::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - acl_check(sender_servername, &body.room_id, &db)?; - - if !db.rooms.is_supported_version(&db, &body.room_version) { - return Err(Error::BadRequest( - ErrorKind::IncompatibleRoomVersion { - room_version: body.room_version.clone(), - }, - "Server does not support this room version.", - )); - } - - let mut signed_event = utils::to_canonical_object(&body.event) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; - - ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut signed_event, - &body.room_version, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; - - // Generate event id - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&signed_event, &body.room_version) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - signed_event.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.into()), - ); - - let sender: Box<_> = serde_json::from_value( - signed_event - .get("sender") - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event had no sender field.", - ))? - .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user id."))?; - - let invited_user: Box<_> = serde_json::from_value( - signed_event - .get("state_key") - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event had no state_key field.", - ))? - .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?; - - let mut invite_state = body.invite_room_state.clone(); - - let mut event: JsonObject = serde_json::from_str(body.event.get()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; - - event.insert("event_id".to_owned(), "$dummy".into()); - - let pdu: PduEvent = serde_json::from_value(event.into()).map_err(|e| { - warn!("Invalid invite event: {}", e); - Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.") - })?; - - invite_state.push(pdu.to_stripped_state_event()); - - // If the room already exists, the remote server will notify us about the join via /send - if !db.rooms.exists(&pdu.room_id)? { - db.rooms.update_membership( - &body.room_id, - &invited_user, - MembershipState::Invite, - &sender, - Some(invite_state), - &db, - true, - )?; - } - - db.flush()?; - - Ok(create_invite::v2::Response { - event: PduEvent::convert_to_outgoing_federation_event(signed_event), - }) -} - -/// # `GET /_matrix/federation/v1/user/devices/{userId}` -/// -/// Gets information on all devices of the user. -pub async fn get_devices_route( - db: DatabaseGuard, - body: Ruma<get_devices::v1::IncomingRequest>, -) -> Result<get_devices::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); - - Ok(get_devices::v1::Response { - user_id: body.user_id.clone(), - stream_id: db - .users - .get_devicelist_version(&body.user_id)? - .unwrap_or(0) - .try_into() - .expect("version will not grow that large"), - devices: db - .users - .all_devices_metadata(&body.user_id) - .filter_map(|r| r.ok()) - .filter_map(|metadata| { - Some(UserDevice { - keys: db - .users - .get_device_keys(&body.user_id, &metadata.device_id) - .ok()??, - device_id: metadata.device_id, - device_display_name: metadata.display_name, - }) - }) - .collect(), - master_key: db - .users - .get_master_key(&body.user_id, |u| u.server_name() == sender_servername)?, - self_signing_key: db - .users - .get_self_signing_key(&body.user_id, |u| u.server_name() == sender_servername)?, - }) -} - -/// # `GET /_matrix/federation/v1/query/directory` -/// -/// Resolve a room alias to a room id. -pub async fn get_room_information_route( - db: DatabaseGuard, - body: Ruma<get_room_information::v1::IncomingRequest>, -) -> Result<get_room_information::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let room_id = db - .rooms - .id_from_alias(&body.room_alias)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Room alias not found.", - ))?; - - Ok(get_room_information::v1::Response { - room_id, - servers: vec![db.globals.server_name().to_owned()], - }) -} - -/// # `GET /_matrix/federation/v1/query/profile` -/// -/// Gets information on a profile. -pub async fn get_profile_information_route( - db: DatabaseGuard, - body: Ruma<get_profile_information::v1::IncomingRequest>, -) -> Result<get_profile_information::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let mut displayname = None; - let mut avatar_url = None; - let mut blurhash = None; - - match &body.field { - Some(ProfileField::DisplayName) => displayname = db.users.displayname(&body.user_id)?, - Some(ProfileField::AvatarUrl) => { - avatar_url = db.users.avatar_url(&body.user_id)?; - blurhash = db.users.blurhash(&body.user_id)? - } - // TODO: what to do with custom - Some(_) => {} - None => { - displayname = db.users.displayname(&body.user_id)?; - avatar_url = db.users.avatar_url(&body.user_id)?; - blurhash = db.users.blurhash(&body.user_id)?; - } - } - - Ok(get_profile_information::v1::Response { - blurhash, - displayname, - avatar_url, - }) -} - -/// # `POST /_matrix/federation/v1/user/keys/query` -/// -/// Gets devices and identity keys for the given users. -pub async fn get_keys_route( - db: DatabaseGuard, - body: Ruma<get_keys::v1::Request>, -) -> Result<get_keys::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let result = get_keys_helper( - None, - &body.device_keys, - |u| Some(u.server_name()) == body.sender_servername.as_deref(), - &db, - ) - .await?; - - db.flush()?; - - Ok(get_keys::v1::Response { - device_keys: result.device_keys, - master_keys: result.master_keys, - self_signing_keys: result.self_signing_keys, - }) -} - -/// # `POST /_matrix/federation/v1/user/keys/claim` -/// -/// Claims one-time keys. -pub async fn claim_keys_route( - db: DatabaseGuard, - body: Ruma<claim_keys::v1::Request>, -) -> Result<claim_keys::v1::Response> { - if !db.globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } - - let result = claim_keys_helper(&body.one_time_keys, &db).await?; - - db.flush()?; - - Ok(claim_keys::v1::Response { - one_time_keys: result.one_time_keys, - }) -} - -#[tracing::instrument(skip_all)] -pub(crate) async fn fetch_required_signing_keys( - event: &BTreeMap<String, CanonicalJsonValue>, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - db: &Database, -) -> Result<()> { - let signatures = event - .get("signatures") - .ok_or(Error::BadServerResponse( - "No signatures in server response pdu.", - ))? - .as_object() - .ok_or(Error::BadServerResponse( - "Invalid signatures object in server response pdu.", - ))?; - - // We go through all the signatures we see on the value and fetch the corresponding signing - // keys - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; - - let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); - - let fetch_res = fetch_signing_keys( - db, - signature_server.as_str().try_into().map_err(|_| { - Error::BadServerResponse("Invalid servername in signatures of server response pdu.") - })?, - signature_ids, - ) - .await; - - let keys = match fetch_res { - Ok(keys) => keys, - Err(_) => { - warn!("Signature verification failed: Could not fetch signing key.",); - continue; - } - }; - - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(signature_server.clone(), keys); - } - - Ok(()) -} - -// Gets a list of servers for which we don't have the signing key yet. We go over -// the PDUs and either cache the key or add it to the list that needs to be retrieved. -fn get_server_keys_from_cache( - pdu: &RawJsonValue, - servers: &mut BTreeMap<Box<ServerName>, BTreeMap<Box<ServerSigningKeyId>, QueryCriteria>>, - room_version: &RoomVersionId, - pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, BTreeMap<String, Base64>>>, - db: &Database, -) -> Result<()> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&value, room_version) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()) - .expect("ruma's reference hashes are valid event ids"); - - if let Some((time, tries)) = db - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(event_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); - } - } - - let signatures = value - .get("signatures") - .ok_or(Error::BadServerResponse( - "No signatures in server response pdu.", - ))? - .as_object() - .ok_or(Error::BadServerResponse( - "Invalid signatures object in server response pdu.", - ))?; - - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; - - let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); - - let contains_all_ids = - |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); - - let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|_| { - Error::BadServerResponse("Invalid servername in signatures of server response pdu.") - })?; - - if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { - continue; - } - - trace!("Loading signing keys for {}", origin); - - let result: BTreeMap<_, _> = db - .globals - .signing_keys_for(origin)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if !contains_all_ids(&result) { - trace!("Signing key not loaded for {}", origin); - servers.insert(origin.to_owned(), BTreeMap::new()); - } - - pub_key_map.insert(origin.to_string(), result); - } - - Ok(()) -} - -pub(crate) async fn fetch_join_signing_keys( - event: &create_join_event::v2::Response, - room_version: &RoomVersionId, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - db: &Database, -) -> Result<()> { - let mut servers: BTreeMap<Box<ServerName>, BTreeMap<Box<ServerSigningKeyId>, QueryCriteria>> = - BTreeMap::new(); - - { - let mut pkm = pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))?; - - // Try to fetch keys, failure is okay - // Servers we couldn't find in the cache will be added to `servers` - for pdu in &event.room_state.state { - let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm, db); - } - for pdu in &event.room_state.auth_chain { - let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm, db); - } - - drop(pkm); - } - - if servers.is_empty() { - // We had all keys locally - return Ok(()); - } - - for server in db.globals.trusted_servers() { - trace!("Asking batch signing keys from trusted server {}", server); - if let Ok(keys) = db - .sending - .send_federation_request( - &db.globals, - server, - get_remote_server_keys_batch::v2::Request { - server_keys: servers.clone(), - }, - ) - .await - { - trace!("Got signing keys: {:?}", keys); - let mut pkm = pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))?; - for k in keys.server_keys { - let k = k.deserialize().unwrap(); - - // TODO: Check signature from trusted server? - servers.remove(&k.server_name); - - let result = db - .globals - .add_signing_key(&k.server_name, k.clone())? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect::<BTreeMap<_, _>>(); - - pkm.insert(k.server_name.to_string(), result); - } - } - - if servers.is_empty() { - return Ok(()); - } - } - - let mut futures: FuturesUnordered<_> = servers - .into_iter() - .map(|(server, _)| async move { - ( - db.sending - .send_federation_request( - &db.globals, - &server, - get_server_keys::v2::Request::new(), - ) - .await, - server, - ) - }) - .collect(); - - while let Some(result) = futures.next().await { - if let (Ok(get_keys_response), origin) = result { - let result: BTreeMap<_, _> = db - .globals - .add_signing_key(&origin, get_keys_response.server_key.deserialize().unwrap())? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(origin.to_string(), result); - } - } - - Ok(()) -} - -/// Returns Ok if the acl allows the server -fn acl_check(server_name: &ServerName, room_id: &RoomId, db: &Database) -> Result<()> { - let acl_event = match db - .rooms - .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? - { - Some(acl) => acl, - None => return Ok(()), - }; - - let acl_event_content: RoomServerAclEventContent = - match serde_json::from_str(acl_event.content.get()) { - Ok(content) => content, - Err(_) => { - warn!("Invalid ACL event"); - return Ok(()); - } - }; - - if acl_event_content.is_allowed(server_name) { - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server was denied by ACL", - )) - } -} - -#[cfg(test)] -mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - - #[test] - fn ips_get_default_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } - - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } - - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ) - } - - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ) - } -} diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs new file mode 100644 index 0000000..c7c9298 --- /dev/null +++ b/src/service/account_data/data.rs @@ -0,0 +1,35 @@ +use std::collections::HashMap; + +use crate::Result; +use ruma::{ + events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + serde::Raw, + RoomId, UserId, +}; + +pub trait Data: Send + Sync { + /// Places one event in the account data of the user and removes the previous entry. + fn update( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + event_type: RoomAccountDataEventType, + data: &serde_json::Value, + ) -> Result<()>; + + /// Searches the account data for a specific kind. + fn get( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + kind: RoomAccountDataEventType, + ) -> Result<Option<Box<serde_json::value::RawValue>>>; + + /// Returns all changes to the account data that happened after `since`. + fn changes_since( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + since: u64, + ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>; +} diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs new file mode 100644 index 0000000..f9c49b1 --- /dev/null +++ b/src/service/account_data/mod.rs @@ -0,0 +1,53 @@ +mod data; + +pub use data::Data; + +use ruma::{ + events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + serde::Raw, + RoomId, UserId, +}; + +use std::collections::HashMap; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Places one event in the account data of the user and removes the previous entry. + #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] + pub fn update( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + event_type: RoomAccountDataEventType, + data: &serde_json::Value, + ) -> Result<()> { + self.db.update(room_id, user_id, event_type, data) + } + + /// Searches the account data for a specific kind. + #[tracing::instrument(skip(self, room_id, user_id, event_type))] + pub fn get( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + event_type: RoomAccountDataEventType, + ) -> Result<Option<Box<serde_json::value::RawValue>>> { + self.db.get(room_id, user_id, event_type) + } + + /// Returns all changes to the account data that happened after `since`. + #[tracing::instrument(skip(self, room_id, user_id, since))] + pub fn changes_since( + &self, + room_id: Option<&RoomId>, + user_id: &UserId, + since: u64, + ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> { + self.db.changes_since(room_id, user_id, since) + } +} diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs new file mode 100644 index 0000000..77f351a --- /dev/null +++ b/src/service/admin/mod.rs @@ -0,0 +1,1173 @@ +use std::{ + collections::BTreeMap, + convert::{TryFrom, TryInto}, + sync::Arc, + time::Instant, +}; + +use clap::Parser; +use regex::Regex; +use ruma::{ + events::{ + room::{ + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + message::RoomMessageEventContent, + name::RoomNameEventContent, + power_levels::RoomPowerLevelsEventContent, + topic::RoomTopicEventContent, + }, + RoomEventType, + }, + EventId, OwnedRoomAliasId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, +}; +use serde_json::value::to_raw_value; +use tokio::sync::{mpsc, Mutex, MutexGuard}; + +use crate::{ + api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH}, + services, + utils::{self, HtmlEscape}, + Error, PduEvent, Result, +}; + +use super::pdu::PduBuilder; + +#[cfg_attr(test, derive(Debug))] +#[derive(Parser)] +#[command(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))] +enum AdminCommand { + #[command(verbatim_doc_comment)] + /// Register an appservice using its registration YAML + /// + /// This command needs a YAML generated by an appservice (such as a bridge), + /// which must be provided in a Markdown code-block below the command. + /// + /// Registering a new bridge using the ID of an existing bridge will replace + /// the old one. + /// + /// [commandbody] + /// # ``` + /// # yaml content here + /// # ``` + RegisterAppservice, + + /// Unregister an appservice using its ID + /// + /// You can find the ID using the `list-appservices` command. + UnregisterAppservice { + /// The appservice to unregister + appservice_identifier: String, + }, + + /// List all the currently registered appservices + ListAppservices, + + /// List all rooms the server knows about + ListRooms, + + /// List users in the database + ListLocalUsers, + + /// List all rooms we are currently handling an incoming pdu from + IncomingFederation, + + /// Deactivate a user + /// + /// User will not be removed from all rooms by default. + /// Use --leave-rooms to force the user to leave all rooms + DeactivateUser { + #[arg(short, long)] + leave_rooms: bool, + user_id: Box<UserId>, + }, + + #[command(verbatim_doc_comment)] + /// Deactivate a list of users + /// + /// Recommended to use in conjunction with list-local-users. + /// + /// Users will not be removed from joined rooms by default. + /// Can be overridden with --leave-rooms flag. + /// Removing a mass amount of users from a room may cause a significant amount of leave events. + /// The time to leave rooms may depend significantly on joined rooms and servers. + /// + /// [commandbody] + /// # ``` + /// # User list here + /// # ``` + DeactivateAll { + #[arg(short, long)] + /// Remove users from their joined rooms + leave_rooms: bool, + #[arg(short, long)] + /// Also deactivate admin accounts + force: bool, + }, + + /// Get the auth_chain of a PDU + GetAuthChain { + /// An event ID (the $ character followed by the base64 reference hash) + event_id: Box<EventId>, + }, + + #[command(verbatim_doc_comment)] + /// Parse and print a PDU from a JSON + /// + /// The PDU event is only checked for validity and is not added to the + /// database. + /// + /// [commandbody] + /// # ``` + /// # PDU json content here + /// # ``` + ParsePdu, + + /// Retrieve and print a PDU by ID from the Conduit database + GetPdu { + /// An event ID (a $ followed by the base64 reference hash) + event_id: Box<EventId>, + }, + + /// Print database memory usage statistics + DatabaseMemoryUsage, + + /// Show configuration values + ShowConfig, + + /// Reset user password + ResetPassword { + /// Username of the user for whom the password should be reset + username: String, + }, + + /// Create a new user + CreateUser { + /// Username of the new user + username: String, + /// Password of the new user, if unspecified one is generated + password: Option<String>, + }, + + /// Disables incoming federation handling for a room. + DisableRoom { room_id: Box<RoomId> }, + /// Enables incoming federation handling for a room again. + EnableRoom { room_id: Box<RoomId> }, +} + +#[derive(Debug)] +pub enum AdminRoomEvent { + ProcessMessage(String), + SendMessage(RoomMessageEventContent), +} + +pub struct Service { + pub sender: mpsc::UnboundedSender<AdminRoomEvent>, + receiver: Mutex<mpsc::UnboundedReceiver<AdminRoomEvent>>, +} + +impl Service { + pub fn build() -> Arc<Self> { + let (sender, receiver) = mpsc::unbounded_channel(); + Arc::new(Self { + sender, + receiver: Mutex::new(receiver), + }) + } + + pub fn start_handler(self: &Arc<Self>) { + let self2 = Arc::clone(self); + tokio::spawn(async move { + self2.handler().await; + }); + } + + async fn handler(&self) { + let mut receiver = self.receiver.lock().await; + // TODO: Use futures when we have long admin commands + //let mut futures = FuturesUnordered::new(); + + let conduit_user = UserId::parse(format!("@conduit:{}", services().globals.server_name())) + .expect("@conduit:server_name is valid"); + + let conduit_room = services() + .rooms + .alias + .resolve_local_alias( + format!("#admins:{}", services().globals.server_name()) + .as_str() + .try_into() + .expect("#admins:server_name is a valid room alias"), + ) + .expect("Database data for admin room alias must be valid") + .expect("Admin room must exist"); + + let send_message = |message: RoomMessageEventContent, mutex_lock: &MutexGuard<'_, ()>| { + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMessage, + content: to_raw_value(&message) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + mutex_lock, + ) + .unwrap(); + }; + + loop { + tokio::select! { + Some(event) = receiver.recv() => { + let message_content = match event { + AdminRoomEvent::SendMessage(content) => content, + AdminRoomEvent::ProcessMessage(room_message) => self.process_admin_message(room_message).await + }; + + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(conduit_room.to_owned()) + .or_default(), + ); + + let state_lock = mutex_state.lock().await; + + send_message(message_content, &state_lock); + + drop(state_lock); + } + } + } + } + + pub fn process_message(&self, room_message: String) { + self.sender + .send(AdminRoomEvent::ProcessMessage(room_message)) + .unwrap(); + } + + pub fn send_message(&self, message_content: RoomMessageEventContent) { + self.sender + .send(AdminRoomEvent::SendMessage(message_content)) + .unwrap(); + } + + // Parse and process a message from the admin room + async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent { + let mut lines = room_message.lines(); + let command_line = lines.next().expect("each string has at least one line"); + let body: Vec<_> = lines.collect(); + + let admin_command = match self.parse_admin_command(command_line) { + Ok(command) => command, + Err(error) => { + let server_name = services().globals.server_name(); + let message = error.replace("server.name", server_name.as_str()); + let html_message = self.usage_to_html(&message, server_name); + + return RoomMessageEventContent::text_html(message, html_message); + } + }; + + match self.process_admin_command(admin_command, body).await { + Ok(reply_message) => reply_message, + Err(error) => { + let markdown_message = format!( + "Encountered an error while handling the command:\n\ + ```\n{error}\n```", + ); + let html_message = format!( + "Encountered an error while handling the command:\n\ + <pre>\n{error}\n</pre>", + ); + + RoomMessageEventContent::text_html(markdown_message, html_message) + } + } + } + + // Parse chat messages from the admin room into an AdminCommand object + fn parse_admin_command(&self, command_line: &str) -> std::result::Result<AdminCommand, String> { + // Note: argv[0] is `@conduit:servername:`, which is treated as the main command + let mut argv: Vec<_> = command_line.split_whitespace().collect(); + + // Replace `help command` with `command --help` + // Clap has a help subcommand, but it omits the long help description. + if argv.len() > 1 && argv[1] == "help" { + argv.remove(1); + argv.push("--help"); + } + + // Backwards compatibility with `register_appservice`-style commands + let command_with_dashes; + if argv.len() > 1 && argv[1].contains('_') { + command_with_dashes = argv[1].replace('_', "-"); + argv[1] = &command_with_dashes; + } + + AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) + } + + async fn process_admin_command( + &self, + command: AdminCommand, + body: Vec<&str>, + ) -> Result<RoomMessageEventContent> { + let reply_message_content = match command { + AdminCommand::RegisterAppservice => { + if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" + { + let appservice_config = body[1..body.len() - 1].join("\n"); + let parsed_config = + serde_yaml::from_str::<serde_yaml::Value>(&appservice_config); + match parsed_config { + Ok(yaml) => match services().appservice.register_appservice(yaml) { + Ok(id) => RoomMessageEventContent::text_plain(format!( + "Appservice registered with ID: {id}." + )), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Failed to register appservice: {e}" + )), + }, + Err(e) => RoomMessageEventContent::text_plain(format!( + "Could not parse appservice config: {e}" + )), + } + } else { + RoomMessageEventContent::text_plain( + "Expected code block in command body. Add --help for details.", + ) + } + } + AdminCommand::UnregisterAppservice { + appservice_identifier, + } => match services() + .appservice + .unregister_appservice(&appservice_identifier) + { + Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Failed to unregister appservice: {e}" + )), + }, + AdminCommand::ListAppservices => { + if let Ok(appservices) = services() + .appservice + .iter_ids() + .map(|ids| ids.collect::<Vec<_>>()) + { + let count = appservices.len(); + let output = format!( + "Appservices ({}): {}", + count, + appservices + .into_iter() + .filter_map(|r| r.ok()) + .collect::<Vec<_>>() + .join(", ") + ); + RoomMessageEventContent::text_plain(output) + } else { + RoomMessageEventContent::text_plain("Failed to get appservices.") + } + } + AdminCommand::ListRooms => { + let room_ids = services().rooms.metadata.iter_ids(); + let output = format!( + "Rooms:\n{}", + room_ids + .filter_map(|r| r.ok()) + .map(|id| id.to_string() + + "\tMembers: " + + &services() + .rooms + .state_cache + .room_joined_count(&id) + .ok() + .flatten() + .unwrap_or(0) + .to_string()) + .collect::<Vec<_>>() + .join("\n") + ); + RoomMessageEventContent::text_plain(output) + } + AdminCommand::ListLocalUsers => match services().users.list_local_users() { + Ok(users) => { + let mut msg: String = format!("Found {} local user account(s):\n", users.len()); + msg += &users.join("\n"); + RoomMessageEventContent::text_plain(&msg) + } + Err(e) => RoomMessageEventContent::text_plain(e.to_string()), + }, + AdminCommand::IncomingFederation => { + let map = services() + .globals + .roomid_federationhandletime + .read() + .unwrap(); + let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); + + for (r, (e, i)) in map.iter() { + let elapsed = i.elapsed(); + msg += &format!( + "{} {}: {}m{}s\n", + r, + e, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); + } + RoomMessageEventContent::text_plain(&msg) + } + AdminCommand::GetAuthChain { event_id } => { + let event_id = Arc::<EventId>::from(event_id); + if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? { + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { + Error::bad_database("Invalid room id field in event in database") + })?; + let start = Instant::now(); + let count = services() + .rooms + .auth_chain + .get_auth_chain(room_id, vec![event_id]) + .await? + .count(); + let elapsed = start.elapsed(); + RoomMessageEventContent::text_plain(format!( + "Loaded auth chain with length {count} in {elapsed:?}" + )) + } else { + RoomMessageEventContent::text_plain("Event not found.") + } + } + AdminCommand::ParsePdu => { + if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" + { + let string = body[1..body.len() - 1].join("\n"); + match serde_json::from_str(&string) { + Ok(value) => { + match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { + Ok(hash) => { + let event_id = EventId::parse(format!("${hash}")); + + match serde_json::from_value::<PduEvent>( + serde_json::to_value(value).expect("value is json"), + ) { + Ok(pdu) => RoomMessageEventContent::text_plain(format!( + "EventId: {event_id:?}\n{pdu:#?}" + )), + Err(e) => RoomMessageEventContent::text_plain(format!( + "EventId: {event_id:?}\nCould not parse event: {e}" + )), + } + } + Err(e) => RoomMessageEventContent::text_plain(format!( + "Could not parse PDU JSON: {e:?}" + )), + } + } + Err(e) => RoomMessageEventContent::text_plain(format!( + "Invalid json in command body: {e}" + )), + } + } else { + RoomMessageEventContent::text_plain("Expected code block in command body.") + } + } + AdminCommand::GetPdu { event_id } => { + let mut outlier = false; + let mut pdu_json = services() + .rooms + .timeline + .get_non_outlier_pdu_json(&event_id)?; + if pdu_json.is_none() { + outlier = true; + pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?; + } + match pdu_json { + Some(json) => { + let json_text = serde_json::to_string_pretty(&json) + .expect("canonical json is valid json"); + RoomMessageEventContent::text_html( + format!( + "{}\n```json\n{}\n```", + if outlier { + "PDU is outlier" + } else { + "PDU was accepted" + }, + json_text + ), + format!( + "<p>{}</p>\n<pre><code class=\"language-json\">{}\n</code></pre>\n", + if outlier { + "PDU is outlier" + } else { + "PDU was accepted" + }, + HtmlEscape(&json_text) + ), + ) + } + None => RoomMessageEventContent::text_plain("PDU not found."), + } + } + AdminCommand::DatabaseMemoryUsage => match services().globals.db.memory_usage() { + Ok(response) => RoomMessageEventContent::text_plain(response), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Failed to get database memory usage: {e}" + )), + }, + AdminCommand::ShowConfig => { + // Construct and send the response + RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) + } + AdminCommand::ResetPassword { username } => { + let user_id = match UserId::parse_with_server_name( + username.as_str().to_lowercase(), + services().globals.server_name(), + ) { + Ok(id) => id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "The supplied username is not a valid username: {e}" + ))) + } + }; + + // Check if the specified user is valid + if !services().users.exists(&user_id)? + || services().users.is_deactivated(&user_id)? + || user_id + == UserId::parse_with_server_name( + "conduit", + services().globals.server_name(), + ) + .expect("conduit user exists") + { + return Ok(RoomMessageEventContent::text_plain( + "The specified user does not exist or is deactivated!", + )); + } + + let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); + + match services() + .users + .set_password(&user_id, Some(new_password.as_str())) + { + Ok(()) => RoomMessageEventContent::text_plain(format!( + "Successfully reset the password for user {user_id}: {new_password}" + )), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Couldn't reset the password for user {user_id}: {e}" + )), + } + } + AdminCommand::CreateUser { username, password } => { + let password = + password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); + // Validate user id + let user_id = match UserId::parse_with_server_name( + username.as_str().to_lowercase(), + services().globals.server_name(), + ) { + Ok(id) => id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "The supplied username is not a valid username: {e}" + ))) + } + }; + if user_id.is_historical() { + return Ok(RoomMessageEventContent::text_plain(format!( + "userid {user_id} is not allowed due to historical" + ))); + } + if services().users.exists(&user_id)? { + return Ok(RoomMessageEventContent::text_plain(format!( + "userid {user_id} already exists" + ))); + } + // Create user + services().users.create(&user_id, Some(password.as_str()))?; + + // Default to pretty displayname + let mut displayname = user_id.localpart().to_owned(); + + // If enabled append lightning bolt to display name (default true) + if services().globals.enable_lightning_bolt() { + displayname.push_str(" ⚡️"); + } + + services() + .users + .set_displayname(&user_id, Some(displayname))?; + + // Initial account data + services().account_data.update( + None, + &user_id, + ruma::events::GlobalAccountDataEventType::PushRules + .to_string() + .into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: ruma::push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json value always works"), + )?; + + // we dont add a device since we're not the user, just the creator + + // Inhibit login does not work for guests + RoomMessageEventContent::text_plain(format!( + "Created user with user_id: {user_id} and password: {password}" + )) + } + AdminCommand::DisableRoom { room_id } => { + services().rooms.metadata.disable_room(&room_id, true)?; + RoomMessageEventContent::text_plain("Room disabled.") + } + AdminCommand::EnableRoom { room_id } => { + services().rooms.metadata.disable_room(&room_id, false)?; + RoomMessageEventContent::text_plain("Room enabled.") + } + AdminCommand::DeactivateUser { + leave_rooms, + user_id, + } => { + let user_id = Arc::<UserId>::from(user_id); + if services().users.exists(&user_id)? { + RoomMessageEventContent::text_plain(format!( + "Making {user_id} leave all rooms before deactivation..." + )); + + services().users.deactivate_account(&user_id)?; + + if leave_rooms { + leave_all_rooms(&user_id).await?; + } + + RoomMessageEventContent::text_plain(format!( + "User {user_id} has been deactivated" + )) + } else { + RoomMessageEventContent::text_plain(format!( + "User {user_id} doesn't exist on this server" + )) + } + } + AdminCommand::DeactivateAll { leave_rooms, force } => { + if body.len() > 2 && body[0].trim() == "```" && body.last().unwrap().trim() == "```" + { + let usernames = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>(); + + let mut user_ids: Vec<&UserId> = Vec::new(); + + for &username in &usernames { + match <&UserId>::try_from(username) { + Ok(user_id) => user_ids.push(user_id), + Err(_) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "{username} is not a valid username" + ))) + } + } + } + + let mut deactivation_count = 0; + let mut admins = Vec::new(); + + if !force { + user_ids.retain(|&user_id| match services().users.is_admin(user_id) { + Ok(is_admin) => match is_admin { + true => { + admins.push(user_id.localpart()); + false + } + false => true, + }, + Err(_) => false, + }) + } + + for &user_id in &user_ids { + if services().users.deactivate_account(user_id).is_ok() { + deactivation_count += 1 + } + } + + if leave_rooms { + for &user_id in &user_ids { + let _ = leave_all_rooms(user_id).await; + } + } + + if admins.is_empty() { + RoomMessageEventContent::text_plain(format!( + "Deactivated {deactivation_count} accounts." + )) + } else { + RoomMessageEventContent::text_plain(format!("Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts", deactivation_count, admins.join(", "))) + } + } else { + RoomMessageEventContent::text_plain( + "Expected code block in command body. Add --help for details.", + ) + } + } + }; + + Ok(reply_message_content) + } + + // Utility to turn clap's `--help` text to HTML. + fn usage_to_html(&self, text: &str, server_name: &ServerName) -> String { + // Replace `@conduit:servername:-subcmdname` with `@conduit:servername: subcmdname` + let text = text.replace( + &format!("@conduit:{server_name}:-"), + &format!("@conduit:{server_name}: "), + ); + + // For the conduit admin room, subcommands become main commands + let text = text.replace("SUBCOMMAND", "COMMAND"); + let text = text.replace("subcommand", "command"); + + // Escape option names (e.g. `<element-id>`) since they look like HTML tags + let text = text.replace('<', "<").replace('>', ">"); + + // Italicize the first line (command name and version text) + let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); + let text = re.replace_all(&text, "<em>$1</em>\n"); + + // Unmerge wrapped lines + let text = text.replace("\n ", " "); + + // Wrap option names in backticks. The lines look like: + // -V, --version Prints version information + // And are converted to: + // <code>-V, --version</code>: Prints version information + // (?m) enables multi-line mode for ^ and $ + let re = Regex::new("(?m)^ (([a-zA-Z_&;-]+(, )?)+) +(.*)$") + .expect("Regex compilation should not fail"); + let text = re.replace_all(&text, "<code>$1</code>: $4"); + + // Look for a `[commandbody]` tag. If it exists, use all lines below it that + // start with a `#` in the USAGE section. + let mut text_lines: Vec<&str> = text.lines().collect(); + let mut command_body = String::new(); + + if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { + text_lines.remove(line_index); + + while text_lines + .get(line_index) + .map(|line| line.starts_with('#')) + .unwrap_or(false) + { + command_body += if text_lines[line_index].starts_with("# ") { + &text_lines[line_index][2..] + } else { + &text_lines[line_index][1..] + }; + command_body += "[nobr]\n"; + text_lines.remove(line_index); + } + } + + let text = text_lines.join("\n"); + + // Improve the usage section + let text = if command_body.is_empty() { + // Wrap the usage line in code tags + let re = Regex::new("(?m)^USAGE:\n (@conduit:.*)$") + .expect("Regex compilation should not fail"); + re.replace_all(&text, "USAGE:\n<code>$1</code>").to_string() + } else { + // Wrap the usage line in a code block, and add a yaml block example + // This makes the usage of e.g. `register-appservice` more accurate + let re = Regex::new("(?m)^USAGE:\n (.*?)\n\n") + .expect("Regex compilation should not fail"); + re.replace_all(&text, "USAGE:\n<pre>$1[nobr]\n[commandbodyblock]</pre>") + .replace("[commandbodyblock]", &command_body) + }; + + // Add HTML line-breaks + + text.replace("\n\n\n", "\n\n") + .replace('\n', "<br>\n") + .replace("[nobr]<br>", "") + } + + /// Create the admin room. + /// + /// Users in this room are considered admins by conduit, and the room can be + /// used to issue admin commands by talking to the server user inside it. + pub(crate) async fn create_admin_room(&self) -> Result<()> { + let room_id = RoomId::new(services().globals.server_name()); + + services().rooms.short.get_or_create_shortroomid(&room_id)?; + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Create a user for the server + let conduit_user = + UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); + + services().users.create(&conduit_user, None)?; + + let mut content = RoomCreateEventContent::new(conduit_user.clone()); + content.federate = true; + content.predecessor = None; + content.room_version = services().globals.default_room_version(); + + // 1. The room create event + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomCreate, + content: to_raw_value(&content).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 2. Make conduit bot join + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(conduit_user.to_string()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 3. Power levels + let mut users = BTreeMap::new(); + users.insert(conduit_user.clone(), 100.into()); + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 4.1 Join Rules + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomJoinRules, + content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 4.2 History Visibility + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomHistoryVisibility, + content: to_raw_value(&RoomHistoryVisibilityEventContent::new( + HistoryVisibility::Shared, + )) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 4.3 Guest Access + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomGuestAccess, + content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 5. Events implied by name and topic + let room_name = format!("{} Admin Room", services().globals.server_name()); + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomName, + content: to_raw_value(&RoomNameEventContent::new(Some(room_name))) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomTopic, + content: to_raw_value(&RoomTopicEventContent { + topic: format!("Manage {}", services().globals.server_name()), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // 6. Room alias + let alias: OwnedRoomAliasId = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomCanonicalAlias, + content: to_raw_value(&RoomCanonicalAliasEventContent { + alias: Some(alias.clone()), + alt_aliases: Vec::new(), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + services().rooms.alias.set_alias(&alias, &room_id)?; + + Ok(()) + } + + /// Invite the user to the conduit admin room. + /// + /// In conduit, this is equivalent to granting admin privileges. + pub(crate) async fn make_user_admin( + &self, + user_id: &UserId, + displayname: String, + ) -> Result<()> { + let admin_room_alias: Box<RoomAliasId> = + format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + let room_id = services() + .rooms + .alias + .resolve_local_alias(&admin_room_alias)? + .expect("Admin room must exist"); + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Use the server user to grant the new admin's power level + let conduit_user = + UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); + + // Invite and join the real user + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: Some(displayname), + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + user_id, + &room_id, + &state_lock, + )?; + + // Set power level + let mut users = BTreeMap::new(); + users.insert(conduit_user.to_owned(), 100.into()); + users.insert(user_id.to_owned(), 100.into()); + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + // Send welcome message + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMessage, + content: to_raw_value(&RoomMessageEventContent::text_html( + format!("## Thank you for trying out Conduit!\n\nConduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Website: https://conduit.rs\n> Git and Documentation: https://gitlab.com/famedly/conduit\n> Report issues: https://gitlab.com/famedly/conduit/-/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nConduit room (Ask questions and get notified on updates):\n`/join #conduit:fachschaften.org`\n\nConduit lounge (Off-topic, only Conduit users are allowed to join)\n`/join #conduit-lounge:conduit.rs`", services().globals.server_name()), + format!("<h2>Thank you for trying out Conduit!</h2>\n<p>Conduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Website: https://conduit.rs<br>Git and Documentation: https://gitlab.com/famedly/conduit<br>Report issues: https://gitlab.com/famedly/conduit/-/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>Conduit room (Ask questions and get notified on updates):<br><code>/join #conduit:fachschaften.org</code></p>\n<p>Conduit lounge (Off-topic, only Conduit users are allowed to join)<br><code>/join #conduit-lounge:conduit.rs</code></p>\n", services().globals.server_name()), + )) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + )?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn get_help_short() { + get_help_inner("-h"); + } + + #[test] + fn get_help_long() { + get_help_inner("--help"); + } + + #[test] + fn get_help_subcommand() { + get_help_inner("help"); + } + + fn get_help_inner(input: &str) { + let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) + .unwrap_err() + .to_string(); + + // Search for a handful of keywords that suggest the help printed properly + assert!(error.contains("Usage:")); + assert!(error.contains("Commands:")); + assert!(error.contains("Options:")); + } +} diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs new file mode 100644 index 0000000..744f0f9 --- /dev/null +++ b/src/service/appservice/data.rs @@ -0,0 +1,19 @@ +use crate::Result; + +pub trait Data: Send + Sync { + /// Registers an appservice and returns the ID to the caller + fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String>; + + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + fn unregister_appservice(&self, service_name: &str) -> Result<()>; + + fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>>; + + fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>; + + fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>>; +} diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs new file mode 100644 index 0000000..3052964 --- /dev/null +++ b/src/service/appservice/mod.rs @@ -0,0 +1,37 @@ +mod data; + +pub use data::Data; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Registers an appservice and returns the ID to the caller + pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> { + self.db.register_appservice(yaml) + } + + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + self.db.unregister_appservice(service_name) + } + + pub fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> { + self.db.get_registration(id) + } + + pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { + self.db.iter_ids() + } + + pub fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>> { + self.db.all() + } +} diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs new file mode 100644 index 0000000..04371a0 --- /dev/null +++ b/src/service/globals/data.rs @@ -0,0 +1,34 @@ +use std::collections::BTreeMap; + +use async_trait::async_trait; +use ruma::{ + api::federation::discovery::{ServerSigningKeys, VerifyKey}, + signatures::Ed25519KeyPair, + DeviceId, OwnedServerSigningKeyId, ServerName, UserId, +}; + +use crate::Result; + +#[async_trait] +pub trait Data: Send + Sync { + fn next_count(&self) -> Result<u64>; + fn current_count(&self) -> Result<u64>; + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + fn cleanup(&self) -> Result<()>; + fn memory_usage(&self) -> Result<String>; + fn load_keypair(&self) -> Result<Ed25519KeyPair>; + fn remove_keypair(&self) -> Result<()>; + fn add_signing_key( + &self, + origin: &ServerName, + new_keys: ServerSigningKeys, + ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>; + + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + fn signing_keys_for( + &self, + origin: &ServerName, + ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>; + fn database_version(&self) -> Result<u64>; + fn bump_database_version(&self, new_version: u64) -> Result<()>; +} diff --git a/src/database/globals.rs b/src/service/globals/mod.rs index 7d7b7fd..bb823e2 100644 --- a/src/database/globals.rs +++ b/src/service/globals/mod.rs @@ -1,11 +1,18 @@ -use crate::{database::Config, server_server::FedDest, utils, Error, Result}; +mod data; +pub use data::Data; +use ruma::{ + OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, +}; + +use crate::api::server_server::FedDest; + +use crate::{Config, Error, Result}; use ruma::{ api::{ client::sync::sync_events, federation::discovery::{ServerSigningKeys, VerifyKey}, }, - DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName, - ServerSigningKeyId, UserId, + DeviceId, RoomVersionId, ServerName, UserId, }; use std::{ collections::{BTreeMap, HashMap}, @@ -20,11 +27,7 @@ use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; use tracing::error; use trust_dns_resolver::TokioAsyncResolver; -use super::abstraction::Tree; - -pub const COUNTER: &[u8] = b"c"; - -type WellKnownMap = HashMap<Box<ServerName>, (FedDest, String)>; +type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>; type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries type SyncHandle = ( @@ -32,27 +35,27 @@ type SyncHandle = ( Receiver<Option<Result<sync_events::v3::Response>>>, // rx ); -pub struct Globals { +pub struct Service { + pub db: &'static dyn Data, + pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host pub tls_name_override: Arc<RwLock<TlsNameMap>>, - pub(super) globals: Arc<dyn Tree>, pub config: Config, keypair: Arc<ruma::signatures::Ed25519KeyPair>, dns_resolver: TokioAsyncResolver, - jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, + jwt_decoding_key: Option<jsonwebtoken::DecodingKey>, federation_client: reqwest::Client, default_client: reqwest::Client, pub stable_room_versions: Vec<RoomVersionId>, pub unstable_room_versions: Vec<RoomVersionId>, - pub(super) server_signingkeys: Arc<dyn Tree>, - pub bad_event_ratelimiter: Arc<RwLock<HashMap<Box<EventId>, RateLimitState>>>, + pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>, pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>, - pub servername_ratelimiter: Arc<RwLock<HashMap<Box<ServerName>, Arc<Semaphore>>>>, - pub sync_receivers: RwLock<HashMap<(Box<UserId>, Box<DeviceId>), SyncHandle>>, - pub roomid_mutex_insert: RwLock<HashMap<Box<RoomId>, Arc<Mutex<()>>>>, - pub roomid_mutex_state: RwLock<HashMap<Box<RoomId>, Arc<TokioMutex<()>>>>, - pub roomid_mutex_federation: RwLock<HashMap<Box<RoomId>, Arc<TokioMutex<()>>>>, // this lock will be held longer - pub roomid_federationhandletime: RwLock<HashMap<Box<RoomId>, (Box<EventId>, Instant)>>, + pub servername_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>, + pub sync_receivers: RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>, + pub roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, + pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>, + pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>, // this lock will be held longer + pub roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>, pub stateres_mutex: Arc<Mutex<()>>, pub rotate: RotationHandler, } @@ -87,47 +90,15 @@ impl Default for RotationHandler { } } -impl Globals { - pub fn load( - globals: Arc<dyn Tree>, - server_signingkeys: Arc<dyn Tree>, - config: Config, - ) -> Result<Self> { - let keypair_bytes = globals.get(b"keypair")?.map_or_else( - || { - let keypair = utils::generate_keypair(); - globals.insert(b"keypair", &keypair)?; - Ok::<_, Error>(keypair) - }, - |s| Ok(s.to_vec()), - )?; - - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); - - let keypair = utils::string_from_bytes( - // 1. version - parts - .next() - .expect("splitn always returns at least one element"), - ) - .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) - .and_then(|version| { - // 2. key - parts - .next() - .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) - .map(|key| (version, key)) - }) - .and_then(|(version, key)| { - ruma::signatures::Ed25519KeyPair::from_der(key, version) - .map_err(|_| Error::bad_database("Private or public keys are invalid.")) - }); +impl Service { + pub fn load(db: &'static dyn Data, config: Config) -> Result<Self> { + let keypair = db.load_keypair(); let keypair = match keypair { Ok(k) => k, Err(e) => { error!("Keypair invalid. Deleting..."); - globals.remove(b"keypair")?; + db.remove_keypair()?; return Err(e); } }; @@ -137,7 +108,7 @@ impl Globals { let jwt_decoding_key = config .jwt_secret .as_ref() - .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()).into_static()); + .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); let default_client = reqwest_client_builder(&config)?.build()?; let name_override = Arc::clone(&tls_name_override); @@ -156,12 +127,13 @@ impl Globals { RoomVersionId::V7, RoomVersionId::V8, RoomVersionId::V9, + RoomVersionId::V10, ]; // Experimental, partially supported room versions let unstable_room_versions = vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5]; let mut s = Self { - globals, + db, config, keypair: Arc::new(keypair), dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| { @@ -175,7 +147,6 @@ impl Globals { tls_name_override, federation_client, default_client, - server_signingkeys, jwt_decoding_key, stable_room_versions, unstable_room_versions, @@ -197,8 +168,8 @@ impl Globals { .supported_room_versions() .contains(&s.config.default_room_version) { - error!("Room version in config isn't supported, falling back to Version 6"); - s.config.default_room_version = RoomVersionId::V6; + error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); + s.config.default_room_version = crate::config::default_default_room_version(); }; Ok(s) @@ -223,16 +194,24 @@ impl Globals { #[tracing::instrument(skip(self))] pub fn next_count(&self) -> Result<u64> { - utils::u64_from_bytes(&self.globals.increment(COUNTER)?) - .map_err(|_| Error::bad_database("Count has invalid bytes.")) + self.db.next_count() } #[tracing::instrument(skip(self))] pub fn current_count(&self) -> Result<u64> { - self.globals.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count has invalid bytes.")) - }) + self.db.current_count() + } + + pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + self.db.watch(user_id, device_id).await + } + + pub fn cleanup(&self) -> Result<()> { + self.db.cleanup() + } + + pub fn memory_usage(&self) -> Result<String> { + self.db.memory_usage() } pub fn server_name(&self) -> &ServerName { @@ -243,6 +222,10 @@ impl Globals { self.config.max_request_size } + pub fn max_fetch_prev_events(&self) -> u16 { + self.config.max_fetch_prev_events + } + pub fn allow_registration(&self) -> bool { self.config.allow_registration } @@ -271,7 +254,7 @@ impl Globals { self.config.enable_lightning_bolt } - pub fn trusted_servers(&self) -> &[Box<ServerName>] { + pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers } @@ -279,7 +262,7 @@ impl Globals { &self.dns_resolver } - pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey<'_>> { + pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } @@ -324,75 +307,24 @@ impl Globals { &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> { - // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; - - let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).ok()) - .unwrap_or_else(|| { - // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) - }); - - let ServerSigningKeys { - verify_keys, - old_verify_keys, - .. - } = new_keys; - - keys.verify_keys.extend(verify_keys.into_iter()); - keys.old_verify_keys.extend(old_verify_keys.into_iter()); - - self.server_signingkeys.insert( - origin.as_bytes(), - &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - )?; - - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - - Ok(tree) + ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { + self.db.add_signing_key(origin, new_keys) } /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. pub fn signing_keys_for( &self, origin: &ServerName, - ) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> { - let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? - .and_then(|bytes| serde_json::from_slice(&bytes).ok()) - .map(|keys: ServerSigningKeys| { - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - tree - }) - .unwrap_or_else(BTreeMap::new); - - Ok(signingkeys) + ) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { + self.db.signing_keys_for(origin) } pub fn database_version(&self) -> Result<u64> { - self.globals.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version) - .map_err(|_| Error::bad_database("Database version id is invalid.")) - }) + self.db.database_version() } pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.globals - .insert(b"version", &new_version.to_be_bytes())?; - Ok(()) + self.db.bump_database_version(new_version) } pub fn get_media_folder(&self) -> PathBuf { diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs new file mode 100644 index 0000000..bf64001 --- /dev/null +++ b/src/service/key_backups/data.rs @@ -0,0 +1,78 @@ +use std::collections::BTreeMap; + +use crate::Result; +use ruma::{ + api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, + serde::Raw, + OwnedRoomId, RoomId, UserId, +}; + +pub trait Data: Send + Sync { + fn create_backup( + &self, + user_id: &UserId, + backup_metadata: &Raw<BackupAlgorithm>, + ) -> Result<String>; + + fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; + + fn update_backup( + &self, + user_id: &UserId, + version: &str, + backup_metadata: &Raw<BackupAlgorithm>, + ) -> Result<String>; + + fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>; + + fn get_latest_backup(&self, user_id: &UserId) + -> Result<Option<(String, Raw<BackupAlgorithm>)>>; + + fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>; + + fn add_key( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + key_data: &Raw<KeyBackupData>, + ) -> Result<()>; + + fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>; + + fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>; + + fn get_all( + &self, + user_id: &UserId, + version: &str, + ) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>; + + fn get_room( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result<BTreeMap<String, Raw<KeyBackupData>>>; + + fn get_session( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + ) -> Result<Option<Raw<KeyBackupData>>>; + + fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>; + + fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>; + + fn delete_room_key( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + ) -> Result<()>; +} diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs new file mode 100644 index 0000000..5fc52ce --- /dev/null +++ b/src/service/key_backups/mod.rs @@ -0,0 +1,127 @@ +mod data; +pub use data::Data; + +use crate::Result; +use ruma::{ + api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, + serde::Raw, + OwnedRoomId, RoomId, UserId, +}; +use std::collections::BTreeMap; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + pub fn create_backup( + &self, + user_id: &UserId, + backup_metadata: &Raw<BackupAlgorithm>, + ) -> Result<String> { + self.db.create_backup(user_id, backup_metadata) + } + + pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + self.db.delete_backup(user_id, version) + } + + pub fn update_backup( + &self, + user_id: &UserId, + version: &str, + backup_metadata: &Raw<BackupAlgorithm>, + ) -> Result<String> { + self.db.update_backup(user_id, version, backup_metadata) + } + + pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> { + self.db.get_latest_backup_version(user_id) + } + + pub fn get_latest_backup( + &self, + user_id: &UserId, + ) -> Result<Option<(String, Raw<BackupAlgorithm>)>> { + self.db.get_latest_backup(user_id) + } + + pub fn get_backup( + &self, + user_id: &UserId, + version: &str, + ) -> Result<Option<Raw<BackupAlgorithm>>> { + self.db.get_backup(user_id, version) + } + + pub fn add_key( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + key_data: &Raw<KeyBackupData>, + ) -> Result<()> { + self.db + .add_key(user_id, version, room_id, session_id, key_data) + } + + pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { + self.db.count_keys(user_id, version) + } + + pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { + self.db.get_etag(user_id, version) + } + + pub fn get_all( + &self, + user_id: &UserId, + version: &str, + ) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> { + self.db.get_all(user_id, version) + } + + pub fn get_room( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result<BTreeMap<String, Raw<KeyBackupData>>> { + self.db.get_room(user_id, version, room_id) + } + + pub fn get_session( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + ) -> Result<Option<Raw<KeyBackupData>>> { + self.db.get_session(user_id, version, room_id, session_id) + } + + pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + self.db.delete_all_keys(user_id, version) + } + + pub fn delete_room_keys( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + ) -> Result<()> { + self.db.delete_room_keys(user_id, version, room_id) + } + + pub fn delete_room_key( + &self, + user_id: &UserId, + version: &str, + room_id: &RoomId, + session_id: &str, + ) -> Result<()> { + self.db + .delete_room_key(user_id, version, room_id, session_id) + } +} diff --git a/src/service/media/data.rs b/src/service/media/data.rs new file mode 100644 index 0000000..75a682c --- /dev/null +++ b/src/service/media/data.rs @@ -0,0 +1,20 @@ +use crate::Result; + +pub trait Data: Send + Sync { + fn create_file_metadata( + &self, + mxc: String, + width: u32, + height: u32, + content_disposition: Option<&str>, + content_type: Option<&str>, + ) -> Result<Vec<u8>>; + + /// Returns content_disposition, content_type and the metadata key. + fn search_file_metadata( + &self, + mxc: String, + width: u32, + height: u32, + ) -> Result<(Option<String>, Option<String>, Vec<u8>)>; +} diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs new file mode 100644 index 0000000..9393753 --- /dev/null +++ b/src/service/media/mod.rs @@ -0,0 +1,226 @@ +mod data; +use std::io::Cursor; + +pub use data::Data; + +use crate::{services, Result}; +use image::imageops::FilterType; + +use tokio::{ + fs::File, + io::{AsyncReadExt, AsyncWriteExt}, +}; + +pub struct FileMeta { + pub content_disposition: Option<String>, + pub content_type: Option<String>, + pub file: Vec<u8>, +} + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Uploads a file. + pub async fn create( + &self, + mxc: String, + content_disposition: Option<&str>, + content_type: Option<&str>, + file: &[u8], + ) -> Result<()> { + // Width, Height = 0 if it's not a thumbnail + let key = self + .db + .create_file_metadata(mxc, 0, 0, content_disposition, content_type)?; + + let path = services().globals.get_media_file(&key); + let mut f = File::create(path).await?; + f.write_all(file).await?; + Ok(()) + } + + /// Uploads or replaces a file thumbnail. + #[allow(clippy::too_many_arguments)] + pub async fn upload_thumbnail( + &self, + mxc: String, + content_disposition: Option<&str>, + content_type: Option<&str>, + width: u32, + height: u32, + file: &[u8], + ) -> Result<()> { + let key = + self.db + .create_file_metadata(mxc, width, height, content_disposition, content_type)?; + + let path = services().globals.get_media_file(&key); + let mut f = File::create(path).await?; + f.write_all(file).await?; + + Ok(()) + } + + /// Downloads a file. + pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> { + if let Ok((content_disposition, content_type, key)) = + self.db.search_file_metadata(mxc, 0, 0) + { + let path = services().globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; + + Ok(Some(FileMeta { + content_disposition, + content_type, + file, + })) + } else { + Ok(None) + } + } + + /// Returns width, height of the thumbnail and whether it should be cropped. Returns None when + /// the server should send the original file. + pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { + match (width, height) { + (0..=32, 0..=32) => Some((32, 32, true)), + (0..=96, 0..=96) => Some((96, 96, true)), + (0..=320, 0..=240) => Some((320, 240, false)), + (0..=640, 0..=480) => Some((640, 480, false)), + (0..=800, 0..=600) => Some((800, 600, false)), + _ => None, + } + } + + /// Downloads a file's thumbnail. + /// + /// Here's an example on how it works: + /// + /// - Client requests an image with width=567, height=567 + /// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails + /// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96) + /// - Server creates the thumbnail and sends it to the user + /// + /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. + pub async fn get_thumbnail( + &self, + mxc: String, + width: u32, + height: u32, + ) -> Result<Option<FileMeta>> { + let (width, height, crop) = self + .thumbnail_properties(width, height) + .unwrap_or((0, 0, false)); // 0, 0 because that's the original file + + if let Ok((content_disposition, content_type, key)) = + self.db.search_file_metadata(mxc.clone(), width, height) + { + // Using saved thumbnail + let path = services().globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; + + Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.to_vec(), + })) + } else if let Ok((content_disposition, content_type, key)) = + self.db.search_file_metadata(mxc.clone(), 0, 0) + { + // Generate a thumbnail + let path = services().globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; + + if let Ok(image) = image::load_from_memory(&file) { + let original_width = image.width(); + let original_height = image.height(); + if width > original_width || height > original_height { + return Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.to_vec(), + })); + } + + let thumbnail = if crop { + image.resize_to_fill(width, height, FilterType::CatmullRom) + } else { + let (exact_width, exact_height) = { + // Copied from image::dynimage::resize_dimensions + let ratio = u64::from(original_width) * u64::from(height); + let nratio = u64::from(width) * u64::from(original_height); + + let use_width = nratio <= ratio; + let intermediate = if use_width { + u64::from(original_height) * u64::from(width) + / u64::from(original_width) + } else { + u64::from(original_width) * u64::from(height) + / u64::from(original_height) + }; + if use_width { + if intermediate <= u64::from(::std::u32::MAX) { + (width, intermediate as u32) + } else { + ( + (u64::from(width) * u64::from(::std::u32::MAX) / intermediate) + as u32, + ::std::u32::MAX, + ) + } + } else if intermediate <= u64::from(::std::u32::MAX) { + (intermediate as u32, height) + } else { + ( + ::std::u32::MAX, + (u64::from(height) * u64::from(::std::u32::MAX) / intermediate) + as u32, + ) + } + }; + + image.thumbnail_exact(exact_width, exact_height) + }; + + let mut thumbnail_bytes = Vec::new(); + thumbnail.write_to( + &mut Cursor::new(&mut thumbnail_bytes), + image::ImageOutputFormat::Png, + )?; + + // Save thumbnail in database so we don't have to generate it again next time + let thumbnail_key = self.db.create_file_metadata( + mxc, + width, + height, + content_disposition.as_deref(), + content_type.as_deref(), + )?; + + let path = services().globals.get_media_file(&thumbnail_key); + let mut f = File::create(path).await?; + f.write_all(&thumbnail_bytes).await?; + + Ok(Some(FileMeta { + content_disposition, + content_type, + file: thumbnail_bytes.to_vec(), + })) + } else { + // Couldn't parse file to generate thumbnail, send original + Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.to_vec(), + })) + } + } else { + Ok(None) + } + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 0000000..385dcc6 --- /dev/null +++ b/src/service/mod.rs @@ -0,0 +1,106 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use lru_cache::LruCache; + +use crate::{Config, Result}; + +pub mod account_data; +pub mod admin; +pub mod appservice; +pub mod globals; +pub mod key_backups; +pub mod media; +pub mod pdu; +pub mod pusher; +pub mod rooms; +pub mod sending; +pub mod transaction_ids; +pub mod uiaa; +pub mod users; + +pub struct Services { + pub appservice: appservice::Service, + pub pusher: pusher::Service, + pub rooms: rooms::Service, + pub transaction_ids: transaction_ids::Service, + pub uiaa: uiaa::Service, + pub users: users::Service, + pub account_data: account_data::Service, + pub admin: Arc<admin::Service>, + pub globals: globals::Service, + pub key_backups: key_backups::Service, + pub media: media::Service, + pub sending: Arc<sending::Service>, +} + +impl Services { + pub fn build< + D: appservice::Data + + pusher::Data + + rooms::Data + + transaction_ids::Data + + uiaa::Data + + users::Data + + account_data::Data + + globals::Data + + key_backups::Data + + media::Data + + sending::Data + + 'static, + >( + db: &'static D, + config: Config, + ) -> Result<Self> { + Ok(Self { + appservice: appservice::Service { db }, + pusher: pusher::Service { db }, + rooms: rooms::Service { + alias: rooms::alias::Service { db }, + auth_chain: rooms::auth_chain::Service { db }, + directory: rooms::directory::Service { db }, + edus: rooms::edus::Service { + presence: rooms::edus::presence::Service { db }, + read_receipt: rooms::edus::read_receipt::Service { db }, + typing: rooms::edus::typing::Service { db }, + }, + event_handler: rooms::event_handler::Service, + lazy_loading: rooms::lazy_loading::Service { + db, + lazy_load_waiting: Mutex::new(HashMap::new()), + }, + metadata: rooms::metadata::Service { db }, + outlier: rooms::outlier::Service { db }, + pdu_metadata: rooms::pdu_metadata::Service { db }, + search: rooms::search::Service { db }, + short: rooms::short::Service { db }, + state: rooms::state::Service { db }, + state_accessor: rooms::state_accessor::Service { db }, + state_cache: rooms::state_cache::Service { db }, + state_compressor: rooms::state_compressor::Service { + db, + stateinfo_cache: Mutex::new(LruCache::new( + (100.0 * config.conduit_cache_capacity_modifier) as usize, + )), + }, + timeline: rooms::timeline::Service { + db, + lasttimelinecount_cache: Mutex::new(HashMap::new()), + }, + user: rooms::user::Service { db }, + }, + transaction_ids: transaction_ids::Service { db }, + uiaa: uiaa::Service { db }, + users: users::Service { db }, + account_data: account_data::Service { db }, + admin: admin::Service::build(), + key_backups: key_backups::Service { db }, + media: media::Service { db }, + sending: sending::Service::build(db, &config), + + globals: globals::Service::load(db, config)?, + }) + } +} diff --git a/src/pdu.rs b/src/service/pdu.rs index 20ec01e..554f3be 100644 --- a/src/pdu.rs +++ b/src/service/pdu.rs @@ -1,11 +1,13 @@ -use crate::{Database, Error}; +use crate::Error; use ruma::{ events::{ - room::member::RoomMemberEventContent, AnyEphemeralRoomEvent, AnyRoomEvent, AnyStateEvent, - AnyStrippedStateEvent, AnySyncRoomEvent, AnySyncStateEvent, RoomEventType, StateEvent, + room::member::RoomMemberEventContent, AnyEphemeralRoomEvent, AnyStateEvent, + AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, + RoomEventType, StateEvent, }, - serde::{CanonicalJsonObject, CanonicalJsonValue, Raw}, - state_res, EventId, MilliSecondsSinceUnixEpoch, RoomId, UInt, UserId, + serde::Raw, + state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, + OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId, }; use serde::{Deserialize, Serialize}; use serde_json::{ @@ -25,8 +27,8 @@ pub struct EventHash { #[derive(Clone, Deserialize, Serialize, Debug)] pub struct PduEvent { pub event_id: Arc<EventId>, - pub room_id: Box<RoomId>, - pub sender: Box<UserId>, + pub room_id: OwnedRoomId, + pub sender: OwnedUserId, pub origin_server_ts: UInt, #[serde(rename = "type")] pub kind: RoomEventType, @@ -102,7 +104,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub fn to_sync_room_event(&self) -> Raw<AnySyncRoomEvent> { + pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> { let mut json = json!({ "content": self.content, "type": self.kind, @@ -146,7 +148,7 @@ impl PduEvent { } #[tracing::instrument(skip(self))] - pub fn to_room_event(&self) -> Raw<AnyRoomEvent> { + pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> { let mut json = json!({ "content": self.content, "type": self.kind, @@ -332,24 +334,17 @@ impl Ord for PduEvent { /// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`. pub(crate) fn gen_event_id_canonical_json( pdu: &RawJsonValue, - db: &Database, -) -> crate::Result<(Box<EventId>, CanonicalJsonObject)> { + room_version_id: &RoomVersionId, +) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { warn!("Error parsing incoming event {:?}: {:?}", pdu, e); Error::BadServerResponse("Invalid PDU in server response") })?; - let room_id = value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or_else(|| Error::bad_database("PDU in db has invalid room_id."))?; - - let room_version_id = db.rooms.get_room_version(&room_id); - let event_id = format!( "${}", // Anything higher than version3 behaves the same - ruma::signatures::reference_hash(&value, &room_version_id?) + ruma::signatures::reference_hash(&value, room_version_id) .expect("ruma can calculate reference hashes") ) .try_into() @@ -358,7 +353,7 @@ pub(crate) fn gen_event_id_canonical_json( Ok((event_id, value)) } -/// Build the start of a PDU in order to add it to the `Database`. +/// Build the start of a PDU in order to add it to the Database. #[derive(Debug, Deserialize)] pub struct PduBuilder { #[serde(rename = "type")] diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs new file mode 100644 index 0000000..2062f56 --- /dev/null +++ b/src/service/pusher/data.rs @@ -0,0 +1,16 @@ +use crate::Result; +use ruma::{ + api::client::push::{set_pusher, Pusher}, + UserId, +}; + +pub trait Data: Send + Sync { + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; + + fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>>; + + fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>; + + fn get_pushkeys<'a>(&'a self, sender: &UserId) + -> Box<dyn Iterator<Item = Result<String>> + 'a>; +} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs new file mode 100644 index 0000000..ba096a2 --- /dev/null +++ b/src/service/pusher/mod.rs @@ -0,0 +1,301 @@ +mod data; +pub use data::Data; +use ruma::events::AnySyncTimelineEvent; + +use crate::{services, Error, PduEvent, Result}; +use bytes::BytesMut; +use ruma::{ + api::{ + client::push::{set_pusher, Pusher, PusherKind}, + push_gateway::send_event_notification::{ + self, + v1::{Device, Notification, NotificationCounts, NotificationPriority}, + }, + IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, + }, + events::{ + room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, + RoomEventType, StateEventType, + }, + push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, + serde::Raw, + uint, RoomId, UInt, UserId, +}; + +use std::{fmt::Debug, mem}; +use tracing::{info, warn}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { + self.db.set_pusher(sender, pusher) + } + + pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> { + self.db.get_pusher(sender, pushkey) + } + + pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { + self.db.get_pushers(sender) + } + + pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> { + self.db.get_pushkeys(sender) + } + + #[tracing::instrument(skip(self, destination, request))] + pub async fn send_request<T: OutgoingRequest>( + &self, + destination: &str, + request: T, + ) -> Result<T::IncomingResponse> + where + T: Debug, + { + let destination = destination.replace("/_matrix/push/v1/notify", ""); + + let http_request = request + .try_into_http_request::<BytesMut>( + &destination, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + })? + .map(|body| body.freeze()); + + let reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); + + // TODO: we could keep this very short and let expo backoff do it's thing... + //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); + + let url = reqwest_request.url().clone(); + let response = services() + .globals + .default_client() + .execute(reqwest_request) + .await; + + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if status != 200 { + info!( + "Push gateway returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + crate::utils::string_from_bytes(&body) + ); + } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + response.map_err(|_| { + info!( + "Push gateway returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Push gateway returned bad response.") + }) + } + Err(e) => { + warn!("Could not send request to pusher {}: {}", destination, e); + Err(e.into()) + } + } + } + + #[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))] + pub async fn send_push_notice( + &self, + user: &UserId, + unread: UInt, + pusher: &Pusher, + ruleset: Ruleset, + pdu: &PduEvent, + ) -> Result<()> { + let mut notify = None; + let mut tweaks = Vec::new(); + + let power_levels: RoomPowerLevelsEventContent = services() + .rooms + .state_accessor + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + for action in self.get_actions( + user, + &ruleset, + &power_levels, + &pdu.to_sync_room_event(), + &pdu.room_id, + )? { + let n = match action { + Action::DontNotify => false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => true, + Action::SetTweak(tweak) => { + tweaks.push(tweak.clone()); + continue; + } + }; + + if notify.is_some() { + return Err(Error::bad_database( + r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, + )); + } + + notify = Some(n); + } + + if notify == Some(true) { + self.send_notice(unread, pusher, tweaks, pdu).await?; + } + // Else the event triggered no actions + + Ok(()) + } + + #[tracing::instrument(skip(self, user, ruleset, pdu))] + pub fn get_actions<'a>( + &self, + user: &UserId, + ruleset: &'a Ruleset, + power_levels: &RoomPowerLevelsEventContent, + pdu: &Raw<AnySyncTimelineEvent>, + room_id: &RoomId, + ) -> Result<&'a [Action]> { + let ctx = PushConditionRoomCtx { + room_id: room_id.to_owned(), + member_count: 10_u32.into(), // TODO: get member count efficiently + user_id: user.to_owned(), + user_display_name: services() + .users + .displayname(user)? + .unwrap_or_else(|| user.localpart().to_owned()), + users_power_levels: power_levels.users.clone(), + default_power_level: power_levels.users_default, + notification_power_levels: power_levels.notifications.clone(), + }; + + Ok(ruleset.get_actions(pdu, &ctx)) + } + + #[tracing::instrument(skip(self, unread, pusher, tweaks, event))] + async fn send_notice( + &self, + unread: UInt, + pusher: &Pusher, + tweaks: Vec<Tweak>, + event: &PduEvent, + ) -> Result<()> { + // TODO: email + match &pusher.kind { + PusherKind::Http(http) => { + // TODO: + // Two problems with this + // 1. if "event_id_only" is the only format kind it seems we should never add more info + // 2. can pusher/devices have conflicting formats + let event_id_only = http.format == Some(PushFormat::EventIdOnly); + + let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone()); + device.data.default_payload = http.default_payload.clone(); + device.data.format = http.format.clone(); + + // Tweaks are only added if the format is NOT event_id_only + if !event_id_only { + device.tweaks = tweaks.clone(); + } + + let d = vec![device]; + let mut notifi = Notification::new(d); + + notifi.prio = NotificationPriority::Low; + notifi.event_id = Some((*event.event_id).to_owned()); + notifi.room_id = Some((*event.room_id).to_owned()); + // TODO: missed calls + notifi.counts = NotificationCounts::new(unread, uint!(0)); + + if event.kind == RoomEventType::RoomEncrypted + || tweaks + .iter() + .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + { + notifi.prio = NotificationPriority::High + } + + if event_id_only { + self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) + .await?; + } else { + notifi.sender = Some(event.sender.clone()); + notifi.event_type = Some(event.kind.clone()); + notifi.content = serde_json::value::to_raw_value(&event.content).ok(); + + if event.kind == RoomEventType::RoomMember { + notifi.user_is_target = + event.state_key.as_deref() == Some(event.sender.as_str()); + } + + notifi.sender_display_name = services().users.displayname(&event.sender)?; + + let room_name = if let Some(room_name_pdu) = services() + .rooms + .state_accessor + .room_state_get(&event.room_id, &StateEventType::RoomName, "")? + { + serde_json::from_str::<RoomNameEventContent>(room_name_pdu.content.get()) + .map_err(|_| { + Error::bad_database("Invalid room name event in database.") + })? + .name + } else { + None + }; + + notifi.room_name = room_name; + + self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) + .await?; + } + + Ok(()) + } + // TODO: Handle email + PusherKind::Email(_) => Ok(()), + _ => Ok(()), + } + } +} diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs new file mode 100644 index 0000000..629b1ee --- /dev/null +++ b/src/service/rooms/alias/data.rs @@ -0,0 +1,19 @@ +use crate::Result; +use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; + +pub trait Data: Send + Sync { + /// Creates or updates the alias to the given room id. + fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>; + + /// Forgets about an alias. Returns an error if the alias did not exist. + fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>; + + /// Looks up the roomid for the given alias. + fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>>; + + /// Returns all local aliases that point to the given room + fn local_aliases_for_room<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>; +} diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs new file mode 100644 index 0000000..d26030c --- /dev/null +++ b/src/service/rooms/alias/mod.rs @@ -0,0 +1,35 @@ +mod data; + +pub use data::Data; + +use crate::Result; +use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + #[tracing::instrument(skip(self))] + pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { + self.db.set_alias(alias, room_id) + } + + #[tracing::instrument(skip(self))] + pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { + self.db.remove_alias(alias) + } + + #[tracing::instrument(skip(self))] + pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { + self.db.resolve_local_alias(alias) + } + + #[tracing::instrument(skip(self))] + pub fn local_aliases_for_room<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> { + self.db.local_aliases_for_room(room_id) + } +} diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs new file mode 100644 index 0000000..e8c379f --- /dev/null +++ b/src/service/rooms/auth_chain/data.rs @@ -0,0 +1,11 @@ +use crate::Result; +use std::{collections::HashSet, sync::Arc}; + +pub trait Data: Send + Sync { + fn get_cached_eventid_authchain( + &self, + shorteventid: &[u64], + ) -> Result<Option<Arc<HashSet<u64>>>>; + fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>) + -> Result<()>; +} diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs new file mode 100644 index 0000000..da1944e --- /dev/null +++ b/src/service/rooms/auth_chain/mod.rs @@ -0,0 +1,161 @@ +mod data; +use std::{ + collections::{BTreeSet, HashSet}, + sync::Arc, +}; + +pub use data::Data; +use ruma::{api::client::error::ErrorKind, EventId, RoomId}; +use tracing::{debug, error, warn}; + +use crate::{services, Error, Result}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> { + self.db.get_cached_eventid_authchain(key) + } + + #[tracing::instrument(skip(self))] + pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> { + self.db.cache_auth_chain(key, auth_chain) + } + + #[tracing::instrument(skip(self, starting_events))] + pub async fn get_auth_chain<'a>( + &self, + room_id: &RoomId, + starting_events: Vec<Arc<EventId>>, + ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> { + const NUM_BUCKETS: usize = 50; + + let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; + + let mut i = 0; + for id in starting_events { + let short = services().rooms.short.get_or_create_shorteventid(&id)?; + let bucket_id = (short % NUM_BUCKETS as u64) as usize; + buckets[bucket_id].insert((short, id.clone())); + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + + let mut full_auth_chain = HashSet::new(); + + let mut hits = 0; + let mut misses = 0; + for chunk in buckets { + if chunk.is_empty() { + continue; + } + + let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect(); + if let Some(cached) = services() + .rooms + .auth_chain + .get_cached_eventid_authchain(&chunk_key)? + { + hits += 1; + full_auth_chain.extend(cached.iter().copied()); + continue; + } + misses += 1; + + let mut chunk_cache = HashSet::new(); + let mut hits2 = 0; + let mut misses2 = 0; + let mut i = 0; + for (sevent_id, event_id) in chunk { + if let Some(cached) = services() + .rooms + .auth_chain + .get_cached_eventid_authchain(&[sevent_id])? + { + hits2 += 1; + chunk_cache.extend(cached.iter().copied()); + } else { + misses2 += 1; + let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); + services() + .rooms + .auth_chain + .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; + debug!( + event_id = ?event_id, + chain_length = ?auth_chain.len(), + "Cache missed event" + ); + chunk_cache.extend(auth_chain.iter()); + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + }; + } + debug!( + chunk_cache_length = ?chunk_cache.len(), + hits = ?hits2, + misses = ?misses2, + "Chunk missed", + ); + let chunk_cache = Arc::new(chunk_cache); + services() + .rooms + .auth_chain + .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; + full_auth_chain.extend(chunk_cache.iter()); + } + + debug!( + chain_length = ?full_auth_chain.len(), + hits = ?hits, + misses = ?misses, + "Auth chain stats", + ); + + Ok(full_auth_chain + .into_iter() + .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) + } + + #[tracing::instrument(skip(self, event_id))] + fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> { + let mut todo = vec![Arc::from(event_id)]; + let mut found = HashSet::new(); + + while let Some(event_id) = todo.pop() { + match services().rooms.timeline.get_pdu(&event_id) { + Ok(Some(pdu)) => { + if pdu.room_id != room_id { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); + } + for auth_event in &pdu.auth_events { + let sauthevent = services() + .rooms + .short + .get_or_create_shorteventid(auth_event)?; + + if !found.contains(&sauthevent) { + found.insert(sauthevent); + todo.push(auth_event.clone()); + } + } + } + Ok(None) => { + warn!(?event_id, "Could not find pdu mentioned in auth events"); + } + Err(error) => { + error!(?event_id, ?error, "Could not load event in auth chain"); + } + } + } + + Ok(found) + } +} diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs new file mode 100644 index 0000000..aca731c --- /dev/null +++ b/src/service/rooms/directory/data.rs @@ -0,0 +1,16 @@ +use crate::Result; +use ruma::{OwnedRoomId, RoomId}; + +pub trait Data: Send + Sync { + /// Adds the room to the public room directory + fn set_public(&self, room_id: &RoomId) -> Result<()>; + + /// Removes the room from the public room directory. + fn set_not_public(&self, room_id: &RoomId) -> Result<()>; + + /// Returns true if the room is in the public room directory. + fn is_public_room(&self, room_id: &RoomId) -> Result<bool>; + + /// Returns the unsorted public room directory + fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; +} diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs new file mode 100644 index 0000000..10f782b --- /dev/null +++ b/src/service/rooms/directory/mod.rs @@ -0,0 +1,32 @@ +mod data; + +pub use data::Data; +use ruma::{OwnedRoomId, RoomId}; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + #[tracing::instrument(skip(self))] + pub fn set_public(&self, room_id: &RoomId) -> Result<()> { + self.db.set_public(room_id) + } + + #[tracing::instrument(skip(self))] + pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { + self.db.set_not_public(room_id) + } + + #[tracing::instrument(skip(self))] + pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { + self.db.is_public_room(room_id) + } + + #[tracing::instrument(skip(self))] + pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { + self.db.public_rooms() + } +} diff --git a/src/service/rooms/edus/mod.rs b/src/service/rooms/edus/mod.rs new file mode 100644 index 0000000..cf7a359 --- /dev/null +++ b/src/service/rooms/edus/mod.rs @@ -0,0 +1,11 @@ +pub mod presence; +pub mod read_receipt; +pub mod typing; + +pub trait Data: presence::Data + read_receipt::Data + typing::Data + 'static {} + +pub struct Service { + pub presence: presence::Service, + pub read_receipt: read_receipt::Service, + pub typing: typing::Service, +} diff --git a/src/service/rooms/edus/presence/data.rs b/src/service/rooms/edus/presence/data.rs new file mode 100644 index 0000000..53329e0 --- /dev/null +++ b/src/service/rooms/edus/presence/data.rs @@ -0,0 +1,38 @@ +use std::collections::HashMap; + +use crate::Result; +use ruma::{events::presence::PresenceEvent, OwnedUserId, RoomId, UserId}; + +pub trait Data: Send + Sync { + /// Adds a presence event which will be saved until a new event replaces it. + /// + /// Note: This method takes a RoomId because presence updates are always bound to rooms to + /// make sure users outside these rooms can't see them. + fn update_presence( + &self, + user_id: &UserId, + room_id: &RoomId, + presence: PresenceEvent, + ) -> Result<()>; + + /// Resets the presence timeout, so the user will stay in their current presence state. + fn ping_presence(&self, user_id: &UserId) -> Result<()>; + + /// Returns the timestamp of the last presence update of this user in millis since the unix epoch. + fn last_presence_update(&self, user_id: &UserId) -> Result<Option<u64>>; + + /// Returns the presence event with correct last_active_ago. + fn get_presence_event( + &self, + room_id: &RoomId, + user_id: &UserId, + count: u64, + ) -> Result<Option<PresenceEvent>>; + + /// Returns the most recent presence updates that happened after the event with id `since`. + fn presence_since( + &self, + room_id: &RoomId, + since: u64, + ) -> Result<HashMap<OwnedUserId, PresenceEvent>>; +} diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs new file mode 100644 index 0000000..860aea1 --- /dev/null +++ b/src/service/rooms/edus/presence/mod.rs @@ -0,0 +1,122 @@ +mod data; +use std::collections::HashMap; + +pub use data::Data; +use ruma::{events::presence::PresenceEvent, OwnedUserId, RoomId, UserId}; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Adds a presence event which will be saved until a new event replaces it. + /// + /// Note: This method takes a RoomId because presence updates are always bound to rooms to + /// make sure users outside these rooms can't see them. + pub fn update_presence( + &self, + user_id: &UserId, + room_id: &RoomId, + presence: PresenceEvent, + ) -> Result<()> { + self.db.update_presence(user_id, room_id, presence) + } + + /// Resets the presence timeout, so the user will stay in their current presence state. + pub fn ping_presence(&self, user_id: &UserId) -> Result<()> { + self.db.ping_presence(user_id) + } + + pub fn get_last_presence_event( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<Option<PresenceEvent>> { + let last_update = match self.db.last_presence_update(user_id)? { + Some(last) => last, + None => return Ok(None), + }; + + self.db.get_presence_event(room_id, user_id, last_update) + } + + /* TODO + /// Sets all users to offline who have been quiet for too long. + fn _presence_maintain( + &self, + rooms: &super::Rooms, + globals: &super::super::globals::Globals, + ) -> Result<()> { + let current_timestamp = utils::millis_since_unix_epoch(); + + for (user_id_bytes, last_timestamp) in self + .userid_lastpresenceupdate + .iter() + .filter_map(|(k, bytes)| { + Some(( + k, + utils::u64_from_bytes(&bytes) + .map_err(|_| { + Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.") + }) + .ok()?, + )) + }) + .take_while(|(_, timestamp)| current_timestamp.saturating_sub(*timestamp) > 5 * 60_000) + // 5 Minutes + { + // Send new presence events to set the user offline + let count = globals.next_count()?.to_be_bytes(); + let user_id: Box<_> = utils::string_from_bytes(&user_id_bytes) + .map_err(|_| { + Error::bad_database("Invalid UserId bytes in userid_lastpresenceupdate.") + })? + .try_into() + .map_err(|_| Error::bad_database("Invalid UserId in userid_lastpresenceupdate."))?; + for room_id in rooms.rooms_joined(&user_id).filter_map(|r| r.ok()) { + let mut presence_id = room_id.as_bytes().to_vec(); + presence_id.push(0xff); + presence_id.extend_from_slice(&count); + presence_id.push(0xff); + presence_id.extend_from_slice(&user_id_bytes); + + self.presenceid_presence.insert( + &presence_id, + &serde_json::to_vec(&PresenceEvent { + content: PresenceEventContent { + avatar_url: None, + currently_active: None, + displayname: None, + last_active_ago: Some( + last_timestamp.try_into().expect("time is valid"), + ), + presence: PresenceState::Offline, + status_msg: None, + }, + sender: user_id.to_owned(), + }) + .expect("PresenceEvent can be serialized"), + )?; + } + + self.userid_lastpresenceupdate.insert( + user_id.as_bytes(), + &utils::millis_since_unix_epoch().to_be_bytes(), + )?; + } + + Ok(()) + }*/ + + /// Returns the most recent presence updates that happened after the event with id `since`. + #[tracing::instrument(skip(self, since, room_id))] + pub fn presence_since( + &self, + room_id: &RoomId, + since: u64, + ) -> Result<HashMap<OwnedUserId, PresenceEvent>> { + self.db.presence_since(room_id, since) + } +} diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs new file mode 100644 index 0000000..a183d19 --- /dev/null +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -0,0 +1,36 @@ +use crate::Result; +use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; + +pub trait Data: Send + Sync { + /// Replaces the previous read receipt. + fn readreceipt_update( + &self, + user_id: &UserId, + room_id: &RoomId, + event: ReceiptEvent, + ) -> Result<()>; + + /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. + fn readreceipts_since<'a>( + &'a self, + room_id: &RoomId, + since: u64, + ) -> Box< + dyn Iterator< + Item = Result<( + OwnedUserId, + u64, + Raw<ruma::events::AnySyncEphemeralRoomEvent>, + )>, + > + 'a, + >; + + /// Sets a private read marker at `count`. + fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; + + /// Returns the private read marker. + fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>; + + /// Returns the count of the last typing update in this room. + fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>; +} diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs new file mode 100644 index 0000000..c603528 --- /dev/null +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -0,0 +1,55 @@ +mod data; + +pub use data::Data; + +use crate::Result; +use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Replaces the previous read receipt. + pub fn readreceipt_update( + &self, + user_id: &UserId, + room_id: &RoomId, + event: ReceiptEvent, + ) -> Result<()> { + self.db.readreceipt_update(user_id, room_id, event) + } + + /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. + #[tracing::instrument(skip(self))] + pub fn readreceipts_since<'a>( + &'a self, + room_id: &RoomId, + since: u64, + ) -> impl Iterator< + Item = Result<( + OwnedUserId, + u64, + Raw<ruma::events::AnySyncEphemeralRoomEvent>, + )>, + > + 'a { + self.db.readreceipts_since(room_id, since) + } + + /// Sets a private read marker at `count`. + #[tracing::instrument(skip(self))] + pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + self.db.private_read_set(room_id, user_id, count) + } + + /// Returns the private read marker. + #[tracing::instrument(skip(self))] + pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { + self.db.private_read_get(room_id, user_id) + } + + /// Returns the count of the last typing update in this room. + pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { + self.db.last_privateread_update(user_id, room_id) + } +} diff --git a/src/service/rooms/edus/typing/data.rs b/src/service/rooms/edus/typing/data.rs new file mode 100644 index 0000000..3b1eecf --- /dev/null +++ b/src/service/rooms/edus/typing/data.rs @@ -0,0 +1,21 @@ +use crate::Result; +use ruma::{OwnedUserId, RoomId, UserId}; +use std::collections::HashSet; + +pub trait Data: Send + Sync { + /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is + /// called. + fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()>; + + /// Removes a user from typing before the timeout is reached. + fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + + /// Makes sure that typing events with old timestamps get removed. + fn typings_maintain(&self, room_id: &RoomId) -> Result<()>; + + /// Returns the count of the last typing update in this room. + fn last_typing_update(&self, room_id: &RoomId) -> Result<u64>; + + /// Returns all user ids currently typing. + fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>>; +} diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs new file mode 100644 index 0000000..7d44f7d --- /dev/null +++ b/src/service/rooms/edus/typing/mod.rs @@ -0,0 +1,49 @@ +mod data; + +pub use data::Data; +use ruma::{events::SyncEphemeralRoomEvent, RoomId, UserId}; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is + /// called. + pub fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { + self.db.typing_add(user_id, room_id, timeout) + } + + /// Removes a user from typing before the timeout is reached. + pub fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + self.db.typing_remove(user_id, room_id) + } + + /// Makes sure that typing events with old timestamps get removed. + fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { + self.db.typings_maintain(room_id) + } + + /// Returns the count of the last typing update in this room. + pub fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> { + self.typings_maintain(room_id)?; + + self.db.last_typing_update(room_id) + } + + /// Returns a new typing EDU. + pub fn typings_all( + &self, + room_id: &RoomId, + ) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> { + let user_ids = self.db.typings_all(room_id)?; + + Ok(SyncEphemeralRoomEvent { + content: ruma::events::typing::TypingEventContent { + user_ids: user_ids.into_iter().collect(), + }, + }) + } +} diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs new file mode 100644 index 0000000..bc67f7a --- /dev/null +++ b/src/service/rooms/event_handler/mod.rs @@ -0,0 +1,1741 @@ +/// An async function that can recursively call itself. +type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>; + +use ruma::{ + api::federation::discovery::{get_remote_server_keys, get_server_keys}, + CanonicalJsonObject, CanonicalJsonValue, OwnedServerName, OwnedServerSigningKeyId, + RoomVersionId, +}; +use std::{ + collections::{hash_map, BTreeMap, HashMap, HashSet}, + pin::Pin, + sync::{Arc, RwLock, RwLockWriteGuard}, + time::{Duration, Instant, SystemTime}, +}; +use tokio::sync::Semaphore; + +use futures_util::{stream::FuturesUnordered, Future, StreamExt}; +use ruma::{ + api::{ + client::error::ErrorKind, + federation::{ + discovery::get_remote_server_keys_batch::{self, v2::QueryCriteria}, + event::{get_event, get_room_state_ids}, + membership::create_join_event, + }, + }, + events::{ + room::{create::RoomCreateEventContent, server_acl::RoomServerAclEventContent}, + StateEventType, + }, + int, + serde::Base64, + state_res::{self, RoomVersion, StateMap}, + uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, +}; +use serde_json::value::RawValue as RawJsonValue; +use tracing::{debug, error, info, trace, warn}; + +use crate::{service::*, services, Error, PduEvent, Result}; + +pub struct Service; + +impl Service { + /// When receiving an event one needs to: + /// 0. Check the server is in the room + /// 1. Skip the PDU if we already know about it + /// 1.1. Remove unsigned field + /// 2. Check signatures, otherwise drop + /// 3. Check content hash, redact if doesn't match + /// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not + /// timeline events + /// 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are + /// also rejected "due to auth events" + /// 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events + /// 7. Persist this event as an outlier + /// 8. If not timeline event: stop + /// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline + /// events + /// 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities + /// doing all the checks in this list starting at 1. These are not timeline events + /// 11. Check the auth of the event passes based on the state of the event + /// 12. Ensure that the state is derived from the previous current state (i.e. we calculated by + /// doing state res where one of the inputs was a previously trusted set of state, don't just + /// trust a set of state we got from a remote) + /// 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" + /// it + /// 14. Use state resolution to find new room state + // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively + #[tracing::instrument(skip(self, value, is_timeline_event, pub_key_map))] + pub(crate) async fn handle_incoming_pdu<'a>( + &self, + origin: &'a ServerName, + event_id: &'a EventId, + room_id: &'a RoomId, + value: BTreeMap<String, CanonicalJsonValue>, + is_timeline_event: bool, + pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, + ) -> Result<Option<Vec<u8>>> { + // 0. Check the server is in the room + if !services().rooms.metadata.exists(room_id)? { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Room is unknown to this server", + )); + } + + if services().rooms.metadata.is_disabled(room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Federation of this room is currently disabled on this server.", + )); + } + + // 1. Skip the PDU if we already have it as a timeline event + if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? { + return Ok(Some(pdu_id.to_vec())); + } + + let create_event = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "")? + .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; + + let create_event_content: RoomCreateEventContent = + serde_json::from_str(create_event.content.get()).map_err(|e| { + error!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + })?; + let room_version_id = &create_event_content.room_version; + + let first_pdu_in_room = services() + .rooms + .timeline + .first_pdu_in_room(room_id)? + .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + + let (incoming_pdu, val) = self + .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, pub_key_map) + .await?; + + // 8. if not timeline event: stop + if !is_timeline_event { + return Ok(None); + } + + // Skip old events + if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + return Ok(None); + } + + // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events + let (sorted_prev_events, mut eventid_info) = self + .fetch_unknown_prev_events( + origin, + &create_event, + room_id, + room_version_id, + pub_key_map, + incoming_pdu.prev_events.clone(), + ) + .await?; + + let mut errors = 0; + debug!(events = ?sorted_prev_events, "Got previous events"); + for prev_id in sorted_prev_events { + // Check for disabled again because it might have changed + if services().rooms.metadata.is_disabled(room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Federation of this room is currently disabled on this server.", + )); + } + + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(&*prev_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", prev_id); + continue; + } + } + + if errors >= 5 { + break; + } + + if let Some((pdu, json)) = eventid_info.remove(&*prev_id) { + // Skip old events + if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + continue; + } + + let start_time = Instant::now(); + services() + .globals + .roomid_federationhandletime + .write() + .unwrap() + .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); + + if let Err(e) = self + .upgrade_outlier_to_timeline_pdu( + pdu, + json, + &create_event, + origin, + room_id, + pub_key_map, + ) + .await + { + errors += 1; + warn!("Prev event {} failed: {}", prev_id, e); + match services() + .globals + .bad_event_ratelimiter + .write() + .unwrap() + .entry((*prev_id).to_owned()) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1) + } + } + } + let elapsed = start_time.elapsed(); + services() + .globals + .roomid_federationhandletime + .write() + .unwrap() + .remove(&room_id.to_owned()); + warn!( + "Handling prev event {} took {}m{}s", + prev_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); + } + } + + // Done with prev events, now handling the incoming event + + let start_time = Instant::now(); + services() + .globals + .roomid_federationhandletime + .write() + .unwrap() + .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); + let r = services() + .rooms + .event_handler + .upgrade_outlier_to_timeline_pdu( + incoming_pdu, + val, + &create_event, + origin, + room_id, + pub_key_map, + ) + .await; + services() + .globals + .roomid_federationhandletime + .write() + .unwrap() + .remove(&room_id.to_owned()); + + r + } + + #[tracing::instrument(skip(self, create_event, value, pub_key_map))] + fn handle_outlier_pdu<'a>( + &'a self, + origin: &'a ServerName, + create_event: &'a PduEvent, + event_id: &'a EventId, + room_id: &'a RoomId, + mut value: BTreeMap<String, CanonicalJsonValue>, + pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, + ) -> AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>> { + Box::pin(async move { + // 1.1. Remove unsigned field + value.remove("unsigned"); + + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json + + // We go through all the signatures we see on the value and fetch the corresponding signing + // keys + self.fetch_required_signing_keys(&value, pub_key_map) + .await?; + + // 2. Check signatures, otherwise drop + // 3. check content hash, redact if doesn't match + let create_event_content: RoomCreateEventContent = + serde_json::from_str(create_event.content.get()).map_err(|e| { + error!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + })?; + + let room_version_id = &create_event_content.room_version; + let room_version = + RoomVersion::new(room_version_id).expect("room version is supported"); + + let mut val = match ruma::signatures::verify_event( + &pub_key_map.read().expect("RwLock is poisoned."), + &value, + room_version_id, + ) { + Err(e) => { + // Drop + warn!("Dropping bad event {}: {}", event_id, e); + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Signature verification failed", + )); + } + Ok(ruma::signatures::Verified::Signatures) => { + // Redact + warn!("Calculated hash does not match: {}", event_id); + match ruma::canonical_json::redact(value, room_version_id, None) { + Ok(obj) => obj, + Err(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Redaction failed", + )) + } + } + } + Ok(ruma::signatures::Verified::All) => value, + }; + + // Now that we have checked the signature and hashes we can add the eventID and convert + // to our PduEvent type + val.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + let incoming_pdu = serde_json::from_value::<PduEvent>( + serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), + ) + .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; + + // 4. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events + // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" + // NOTE: Step 5 is not applied anymore because it failed too often + debug!(event_id = ?incoming_pdu.event_id, "Fetching auth events"); + self.fetch_and_handle_outliers( + origin, + &incoming_pdu + .auth_events + .iter() + .map(|x| Arc::from(&**x)) + .collect::<Vec<_>>(), + create_event, + room_id, + room_version_id, + pub_key_map, + ) + .await; + + // 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events + info!( + "Auth check for {} based on auth events", + incoming_pdu.event_id + ); + + // Build map of auth events + let mut auth_events = HashMap::new(); + for id in &incoming_pdu.auth_events { + let auth_event = match services().rooms.timeline.get_pdu(id)? { + Some(e) => e, + None => { + warn!("Could not find auth event {}", id); + continue; + } + }; + + match auth_events.entry(( + auth_event.kind.to_string().into(), + auth_event + .state_key + .clone() + .expect("all auth events have state keys"), + )) { + hash_map::Entry::Vacant(v) => { + v.insert(auth_event); + } + hash_map::Entry::Occupied(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth event's type and state_key combination exists multiple times.", + )); + } + } + } + + // The original create event must be in the auth events + if auth_events + .get(&(StateEventType::RoomCreate, "".to_owned())) + .map(|a| a.as_ref()) + != Some(create_event) + { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Incoming event refers to wrong create event.", + )); + } + + if !state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None::<PduEvent>, // TODO: third party invite + |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), + ) + .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed"))? + { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth check failed", + )); + } + + info!("Validation successful."); + + // 7. Persist the event as an outlier. + services() + .rooms + .outlier + .add_pdu_outlier(&incoming_pdu.event_id, &val)?; + + info!("Added pdu as outlier."); + + Ok((Arc::new(incoming_pdu), val)) + }) + } + + #[tracing::instrument(skip(self, incoming_pdu, val, create_event, pub_key_map))] + pub async fn upgrade_outlier_to_timeline_pdu( + &self, + incoming_pdu: Arc<PduEvent>, + val: BTreeMap<String, CanonicalJsonValue>, + create_event: &PduEvent, + origin: &ServerName, + room_id: &RoomId, + pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, + ) -> Result<Option<Vec<u8>>> { + // Skip the PDU if we already have it as a timeline event + if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) { + return Ok(Some(pduid)); + } + + if services() + .rooms + .pdu_metadata + .is_event_soft_failed(&incoming_pdu.event_id)? + { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Event has been soft failed", + )); + } + + info!("Upgrading {} to timeline pdu", incoming_pdu.event_id); + + let create_event_content: RoomCreateEventContent = + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + })?; + + let room_version_id = &create_event_content.room_version; + let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + + // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities + // doing all the checks in this list starting at 1. These are not timeline events. + + // TODO: if we know the prev_events of the incoming event we can avoid the request and build + // the state from a known point and resolve if > 1 prev_event + + info!("Requesting state at event"); + let mut state_at_incoming_event = None; + + if incoming_pdu.prev_events.len() == 1 { + let prev_event = &*incoming_pdu.prev_events[0]; + let prev_event_sstatehash = services() + .rooms + .state_accessor + .pdu_shortstatehash(prev_event)?; + + let state = if let Some(shortstatehash) = prev_event_sstatehash { + Some( + services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await, + ) + } else { + None + }; + + if let Some(Ok(mut state)) = state { + info!("Using cached state"); + let prev_pdu = services() + .rooms + .timeline + .get_pdu(prev_event) + .ok() + .flatten() + .ok_or_else(|| { + Error::bad_database("Could not find prev event, but we know the state.") + })?; + + if let Some(state_key) = &prev_pdu.state_key { + let shortstatekey = services().rooms.short.get_or_create_shortstatekey( + &prev_pdu.kind.to_string().into(), + state_key, + )?; + + state.insert(shortstatekey, Arc::from(prev_event)); + // Now it's the state after the pdu + } + + state_at_incoming_event = Some(state); + } + } else { + info!("Calculating state at event using state res"); + let mut extremity_sstatehashes = HashMap::new(); + + let mut okay = true; + for prev_eventid in &incoming_pdu.prev_events { + let prev_event = + if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(prev_eventid) { + pdu + } else { + okay = false; + break; + }; + + let sstatehash = if let Ok(Some(s)) = services() + .rooms + .state_accessor + .pdu_shortstatehash(prev_eventid) + { + s + } else { + okay = false; + break; + }; + + extremity_sstatehashes.insert(sstatehash, prev_event); + } + + if okay { + let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); + let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); + + for (sstatehash, prev_event) in extremity_sstatehashes { + let mut leaf_state: HashMap<_, _> = services() + .rooms + .state_accessor + .state_full_ids(sstatehash) + .await?; + + if let Some(state_key) = &prev_event.state_key { + let shortstatekey = services().rooms.short.get_or_create_shortstatekey( + &prev_event.kind.to_string().into(), + state_key, + )?; + leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); + // Now it's the state after the pdu + } + + let mut state = StateMap::with_capacity(leaf_state.len()); + let mut starting_events = Vec::with_capacity(leaf_state.len()); + + for (k, id) in leaf_state { + if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) + { + // FIXME: Undo .to_string().into() when StateMap + // is updated to use StateEventType + state.insert((ty.to_string().into(), st_key), id.clone()); + } else { + warn!("Failed to get_statekey_from_short."); + } + starting_events.push(id); + } + + auth_chain_sets.push( + services() + .rooms + .auth_chain + .get_auth_chain(room_id, starting_events) + .await? + .collect(), + ); + + fork_states.push(state); + } + + let lock = services().globals.stateres_mutex.lock(); + + let result = + state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { + let res = services().rooms.timeline.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); + } + res.ok().flatten() + }); + drop(lock); + + state_at_incoming_event = match result { + Ok(new_state) => Some( + new_state + .into_iter() + .map(|((event_type, state_key), event_id)| { + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey( + &event_type.to_string().into(), + &state_key, + )?; + Ok((shortstatekey, event_id)) + }) + .collect::<Result<_>>()?, + ), + Err(e) => { + warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); + None + } + } + } + } + + if state_at_incoming_event.is_none() { + info!("Calling /state_ids"); + // Call /state_ids to find out what the state at this pdu is. We trust the server's + // response to some extend, but we still do a lot of checks on the events + match services() + .sending + .send_federation_request( + origin, + get_room_state_ids::v1::Request { + room_id: room_id.to_owned(), + event_id: (*incoming_pdu.event_id).to_owned(), + }, + ) + .await + { + Ok(res) => { + info!("Fetching state events at event."); + let state_vec = self + .fetch_and_handle_outliers( + origin, + &res.pdu_ids + .iter() + .map(|x| Arc::from(&**x)) + .collect::<Vec<_>>(), + create_event, + room_id, + room_version_id, + pub_key_map, + ) + .await; + + let mut state: HashMap<_, Arc<EventId>> = HashMap::new(); + for (pdu, _) in state_vec { + let state_key = pdu.state_key.clone().ok_or_else(|| { + Error::bad_database("Found non-state pdu in state events.") + })?; + + let shortstatekey = services().rooms.short.get_or_create_shortstatekey( + &pdu.kind.to_string().into(), + &state_key, + )?; + + match state.entry(shortstatekey) { + hash_map::Entry::Vacant(v) => { + v.insert(Arc::from(&*pdu.event_id)); + } + hash_map::Entry::Occupied(_) => return Err( + Error::bad_database("State event's type and state_key combination exists multiple times."), + ), + } + } + + // The original create event must still be in the state + let create_shortstatekey = services() + .rooms + .short + .get_shortstatekey(&StateEventType::RoomCreate, "")? + .expect("Room exists"); + + if state.get(&create_shortstatekey).map(|id| id.as_ref()) + != Some(&create_event.event_id) + { + return Err(Error::bad_database( + "Incoming event refers to wrong create event.", + )); + } + + state_at_incoming_event = Some(state); + } + Err(e) => { + warn!("Fetching state for event failed: {}", e); + return Err(e); + } + }; + } + + let state_at_incoming_event = + state_at_incoming_event.expect("we always set this to some above"); + + info!("Starting auth check"); + // 11. Check the auth of the event passes based on the state of the event + let check_result = state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None::<PduEvent>, // TODO: third party invite + |k, s| { + services() + .rooms + .short + .get_shortstatekey(&k.to_string().into(), s) + .ok() + .flatten() + .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) + .and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten()) + }, + ) + .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; + + if !check_result { + return Err(Error::bad_database( + "Event has failed auth check with state at the event.", + )); + } + info!("Auth check succeeded"); + + // We start looking at current room state now, so lets lock the room + + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Now we calculate the set of extremities this room has after the incoming event has been + // applied. We start with the previous extremities (aka leaves) + info!("Calculating extremities"); + let mut extremities = services().rooms.state.get_forward_extremities(room_id)?; + + // Remove any forward extremities that are referenced by this incoming event's prev_events + for prev_event in &incoming_pdu.prev_events { + if extremities.contains(prev_event) { + extremities.remove(prev_event); + } + } + + // Only keep those extremities were not referenced yet + extremities.retain(|id| { + !matches!( + services() + .rooms + .pdu_metadata + .is_event_referenced(room_id, id), + Ok(true) + ) + }); + + info!("Compressing state at event"); + let state_ids_compressed = state_at_incoming_event + .iter() + .map(|(shortstatekey, id)| { + services() + .rooms + .state_compressor + .compress_state_event(*shortstatekey, id) + }) + .collect::<Result<_>>()?; + + // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it + info!("Starting soft fail auth check"); + + let auth_events = services().rooms.state.get_auth_events( + room_id, + &incoming_pdu.kind, + &incoming_pdu.sender, + incoming_pdu.state_key.as_deref(), + &incoming_pdu.content, + )?; + + let soft_fail = !state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None::<PduEvent>, + |k, s| auth_events.get(&(k.clone(), s.to_owned())), + ) + .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; + + if soft_fail { + services().rooms.timeline.append_incoming_pdu( + &incoming_pdu, + val, + extremities.iter().map(|e| (**e).to_owned()).collect(), + state_ids_compressed, + soft_fail, + &state_lock, + )?; + + // Soft fail, we keep the event as an outlier but don't add it to the timeline + warn!("Event was soft failed: {:?}", incoming_pdu); + services() + .rooms + .pdu_metadata + .mark_event_soft_failed(&incoming_pdu.event_id)?; + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Event has been soft failed", + )); + } + + if incoming_pdu.state_key.is_some() { + info!("Loading current room state ids"); + let current_sstatehash = services() + .rooms + .state + .get_room_shortstatehash(room_id)? + .expect("every room has state"); + + let current_state_ids = services() + .rooms + .state_accessor + .state_full_ids(current_sstatehash) + .await?; + + info!("Preparing for stateres to derive new room state"); + let mut extremity_sstatehashes = HashMap::new(); + + info!(?extremities, "Loading extremities"); + for id in &extremities { + match services().rooms.timeline.get_pdu(id)? { + Some(leaf_pdu) => { + extremity_sstatehashes.insert( + services() + .rooms + .state_accessor + .pdu_shortstatehash(&leaf_pdu.event_id)? + .ok_or_else(|| { + error!( + "Found extremity pdu with no statehash in db: {:?}", + leaf_pdu + ); + Error::bad_database("Found pdu with no statehash in db.") + })?, + leaf_pdu, + ); + } + _ => { + error!("Missing state snapshot for {:?}", id); + return Err(Error::BadDatabase("Missing state snapshot.")); + } + } + } + + let mut fork_states = Vec::new(); + + // 12. Ensure that the state is derived from the previous current state (i.e. we calculated + // by doing state res where one of the inputs was a previously trusted set of state, + // don't just trust a set of state we got from a remote). + + // We do this by adding the current state to the list of fork states + extremity_sstatehashes.remove(¤t_sstatehash); + fork_states.push(current_state_ids); + + // We also add state after incoming event to the fork states + let mut state_after = state_at_incoming_event.clone(); + if let Some(state_key) = &incoming_pdu.state_key { + let shortstatekey = services().rooms.short.get_or_create_shortstatekey( + &incoming_pdu.kind.to_string().into(), + state_key, + )?; + + state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); + } + fork_states.push(state_after); + + let mut update_state = false; + // 14. Use state resolution to find new room state + let new_room_state = if fork_states.is_empty() { + panic!("State is empty"); + } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { + info!("State resolution trivial"); + // There was only one state, so it has to be the room's current state (because that is + // always included) + fork_states[0] + .iter() + .map(|(k, id)| { + services() + .rooms + .state_compressor + .compress_state_event(*k, id) + }) + .collect::<Result<_>>()? + } else { + info!("Loading auth chains"); + // We do need to force an update to this room's state + update_state = true; + + let mut auth_chain_sets = Vec::new(); + for state in &fork_states { + auth_chain_sets.push( + services() + .rooms + .auth_chain + .get_auth_chain( + room_id, + state.iter().map(|(_, id)| id.clone()).collect(), + ) + .await? + .collect(), + ); + } + + info!("Loading fork states"); + + let fork_states: Vec<_> = fork_states + .into_iter() + .map(|map| { + map.into_iter() + .filter_map(|(k, id)| { + services() + .rooms + .short + .get_statekey_from_short(k) + .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) + .ok() + }) + .collect::<StateMap<_>>() + }) + .collect(); + + info!("Resolving state"); + + let lock = services().globals.stateres_mutex.lock(); + let state = match state_res::resolve( + room_version_id, + &fork_states, + auth_chain_sets, + |id| { + let res = services().rooms.timeline.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); + } + res.ok().flatten() + }, + ) { + Ok(new_state) => new_state, + Err(_) => { + return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization")); + } + }; + + drop(lock); + + info!("State resolution done. Compressing state"); + + state + .into_iter() + .map(|((event_type, state_key), event_id)| { + let shortstatekey = services().rooms.short.get_or_create_shortstatekey( + &event_type.to_string().into(), + &state_key, + )?; + services() + .rooms + .state_compressor + .compress_state_event(shortstatekey, &event_id) + }) + .collect::<Result<_>>()? + }; + + // Set the new room state to the resolved state + if update_state { + info!("Forcing new room state"); + let (sstatehash, new, removed) = services() + .rooms + .state_compressor + .save_state(room_id, new_room_state)?; + services() + .rooms + .state + .force_state(room_id, sstatehash, new, removed, &state_lock) + .await?; + } + } + + info!("Appending pdu to timeline"); + extremities.insert(incoming_pdu.event_id.clone()); + + // Now that the event has passed all auth it is added into the timeline. + // We use the `state_at_event` instead of `state_after` so we accurately + // represent the state for this event. + + let pdu_id = services().rooms.timeline.append_incoming_pdu( + &incoming_pdu, + val, + extremities.iter().map(|e| (**e).to_owned()).collect(), + state_ids_compressed, + soft_fail, + &state_lock, + )?; + + info!("Appended incoming pdu"); + + // Event has passed all auth/stateres checks + drop(state_lock); + Ok(pdu_id) + } + + /// Find the event and auth it. Once the event is validated (steps 1 - 8) + /// it is appended to the outliers Tree. + /// + /// Returns pdu and if we fetched it over federation the raw json. + /// + /// a. Look in the main timeline (pduid_pdu tree) + /// b. Look at outlier pdu tree + /// c. Ask origin server over federation + /// d. TODO: Ask other servers over federation? + #[tracing::instrument(skip_all)] + pub(crate) fn fetch_and_handle_outliers<'a>( + &'a self, + origin: &'a ServerName, + events: &'a [Arc<EventId>], + create_event: &'a PduEvent, + room_id: &'a RoomId, + room_version_id: &'a RoomVersionId, + pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, + ) -> AsyncRecursiveType<'a, Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>> + { + Box::pin(async move { + let back_off = |id| match services() + .globals + .bad_event_ratelimiter + .write() + .unwrap() + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; + + let mut pdus = vec![]; + for id in events { + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(&**id) + { + // Exponential backoff + let mut min_elapsed_duration = + Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", id); + continue; + } + } + + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { + trace!("Found {} in db", id); + pdus.push((local_pdu, None)); + continue; + } + + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. + let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::new(); + let mut events_all = HashSet::new(); + let mut i = 0; + while let Some(next_id) = todo_auth_events.pop() { + if events_all.contains(&next_id) { + continue; + } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + + if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { + trace!("Found {} in db", id); + continue; + } + + info!("Fetching {} over federation.", next_id); + match services() + .sending + .send_federation_request( + origin, + get_event::v1::Request { + event_id: (*next_id).to_owned(), + }, + ) + .await + { + Ok(res) => { + info!("Got {} over federation", next_id); + let (calculated_event_id, value) = + match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) { + Ok(t) => t, + Err(_) => { + back_off((*next_id).to_owned()); + continue; + } + }; + + if calculated_event_id != *next_id { + warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", + next_id, calculated_event_id, &res.pdu); + } + + if let Some(auth_events) = + value.get("auth_events").and_then(|c| c.as_array()) + { + for auth_event in auth_events { + if let Ok(auth_event) = + serde_json::from_value(auth_event.clone().into()) + { + let a: Arc<EventId> = auth_event; + todo_auth_events.push(a); + } else { + warn!("Auth event id is not valid"); + } + } + } else { + warn!("Auth event list invalid"); + } + + events_in_reverse_order.push((next_id.clone(), value)); + events_all.insert(next_id); + } + Err(_) => { + warn!("Failed to fetch event: {}", next_id); + back_off((*next_id).to_owned()); + } + } + } + + for (next_id, value) in events_in_reverse_order.iter().rev() { + match self + .handle_outlier_pdu( + origin, + create_event, + next_id, + room_id, + value.clone(), + pub_key_map, + ) + .await + { + Ok((pdu, json)) => { + if next_id == id { + pdus.push((pdu, Some(json))); + } + } + Err(e) => { + warn!("Authentication of event {} failed: {:?}", next_id, e); + back_off((**next_id).to_owned()); + } + } + } + } + pdus + }) + } + + async fn fetch_unknown_prev_events( + &self, + origin: &ServerName, + create_event: &PduEvent, + room_id: &RoomId, + room_version_id: &RoomVersionId, + pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, + initial_set: Vec<Arc<EventId>>, + ) -> Result<( + Vec<Arc<EventId>>, + HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>, + )> { + let mut graph: HashMap<Arc<EventId>, _> = HashMap::new(); + let mut eventid_info = HashMap::new(); + let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set; + + let first_pdu_in_room = services() + .rooms + .timeline + .first_pdu_in_room(room_id)? + .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + + let mut amount = 0; + + while let Some(prev_event_id) = todo_outlier_stack.pop() { + if let Some((pdu, json_opt)) = self + .fetch_and_handle_outliers( + origin, + &[prev_event_id.clone()], + create_event, + room_id, + room_version_id, + pub_key_map, + ) + .await + .pop() + { + if amount > services().globals.max_fetch_prev_events() { + // Max limit reached + warn!("Max prev event limit reached!"); + graph.insert(prev_event_id.clone(), HashSet::new()); + continue; + } + + if let Some(json) = json_opt.or_else(|| { + services() + .rooms + .outlier + .get_outlier_pdu_json(&prev_event_id) + .ok() + .flatten() + }) { + if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { + amount += 1; + for prev_prev in &pdu.prev_events { + if !graph.contains_key(prev_prev) { + todo_outlier_stack.push(prev_prev.clone()); + } + } + + graph.insert( + prev_event_id.clone(), + pdu.prev_events.iter().cloned().collect(), + ); + } else { + // Time based check failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + + eventid_info.insert(prev_event_id.clone(), (pdu, json)); + } else { + // Get json failed, so this was not fetched over federation + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } else { + // Fetch and handle failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } + + let sorted = state_res::lexicographical_topological_sort(&graph, |event_id| { + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. + Ok(( + int!(0), + MilliSecondsSinceUnixEpoch( + eventid_info + .get(event_id) + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), + ), + )) + }) + .map_err(|_| Error::bad_database("Error sorting prev events"))?; + + Ok((sorted, eventid_info)) + } + + #[tracing::instrument(skip_all)] + pub(crate) async fn fetch_required_signing_keys( + &self, + event: &BTreeMap<String, CanonicalJsonValue>, + pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, + ) -> Result<()> { + let signatures = event + .get("signatures") + .ok_or(Error::BadServerResponse( + "No signatures in server response pdu.", + ))? + .as_object() + .ok_or(Error::BadServerResponse( + "Invalid signatures object in server response pdu.", + ))?; + + // We go through all the signatures we see on the value and fetch the corresponding signing + // keys + for (signature_server, signature) in signatures { + let signature_object = signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; + + let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); + + let fetch_res = self + .fetch_signing_keys( + signature_server.as_str().try_into().map_err(|_| { + Error::BadServerResponse( + "Invalid servername in signatures of server response pdu.", + ) + })?, + signature_ids, + ) + .await; + + let keys = match fetch_res { + Ok(keys) => keys, + Err(_) => { + warn!("Signature verification failed: Could not fetch signing key.",); + continue; + } + }; + + pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))? + .insert(signature_server.clone(), keys); + } + + Ok(()) + } + + // Gets a list of servers for which we don't have the signing key yet. We go over + // the PDUs and either cache the key or add it to the list that needs to be retrieved. + fn get_server_keys_from_cache( + &self, + pdu: &RawJsonValue, + servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>, + room_version: &RoomVersionId, + pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, BTreeMap<String, Base64>>>, + ) -> Result<()> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&value, room_version) + .expect("ruma can calculate reference hashes") + ); + let event_id = <&EventId>::try_from(event_id.as_str()) + .expect("ruma's reference hashes are valid event ids"); + + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(event_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {}", event_id); + return Err(Error::BadServerResponse("bad event, still backing off")); + } + } + + let signatures = value + .get("signatures") + .ok_or(Error::BadServerResponse( + "No signatures in server response pdu.", + ))? + .as_object() + .ok_or(Error::BadServerResponse( + "Invalid signatures object in server response pdu.", + ))?; + + for (signature_server, signature) in signatures { + let signature_object = signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; + + let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); + + let contains_all_ids = |keys: &BTreeMap<String, Base64>| { + signature_ids.iter().all(|id| keys.contains_key(id)) + }; + + let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|_| { + Error::BadServerResponse("Invalid servername in signatures of server response pdu.") + })?; + + if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { + continue; + } + + trace!("Loading signing keys for {}", origin); + + let result: BTreeMap<_, _> = services() + .globals + .signing_keys_for(origin)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); + + if !contains_all_ids(&result) { + trace!("Signing key not loaded for {}", origin); + servers.insert(origin.to_owned(), BTreeMap::new()); + } + + pub_key_map.insert(origin.to_string(), result); + } + + Ok(()) + } + + pub(crate) async fn fetch_join_signing_keys( + &self, + event: &create_join_event::v2::Response, + room_version: &RoomVersionId, + pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, + ) -> Result<()> { + let mut servers: BTreeMap< + OwnedServerName, + BTreeMap<OwnedServerSigningKeyId, QueryCriteria>, + > = BTreeMap::new(); + + { + let mut pkm = pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))?; + + // Try to fetch keys, failure is okay + // Servers we couldn't find in the cache will be added to `servers` + for pdu in &event.room_state.state { + let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); + } + for pdu in &event.room_state.auth_chain { + let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); + } + + drop(pkm); + } + + if servers.is_empty() { + // We had all keys locally + return Ok(()); + } + + for server in services().globals.trusted_servers() { + trace!("Asking batch signing keys from trusted server {}", server); + if let Ok(keys) = services() + .sending + .send_federation_request( + server, + get_remote_server_keys_batch::v2::Request { + server_keys: servers.clone(), + }, + ) + .await + { + trace!("Got signing keys: {:?}", keys); + let mut pkm = pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))?; + for k in keys.server_keys { + let k = match k.deserialize() { + Ok(key) => key, + Err(e) => { + warn!( + "Received error {} while fetching keys from trusted server {}", + e, server + ); + warn!("{}", k.into_json()); + continue; + } + }; + + // TODO: Check signature from trusted server? + servers.remove(&k.server_name); + + let result = services() + .globals + .add_signing_key(&k.server_name, k.clone())? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect::<BTreeMap<_, _>>(); + + pkm.insert(k.server_name.to_string(), result); + } + } + + if servers.is_empty() { + return Ok(()); + } + } + + let mut futures: FuturesUnordered<_> = servers + .into_keys() + .map(|server| async move { + ( + services() + .sending + .send_federation_request(&server, get_server_keys::v2::Request::new()) + .await, + server, + ) + }) + .collect(); + + while let Some(result) = futures.next().await { + if let (Ok(get_keys_response), origin) = result { + let result: BTreeMap<_, _> = services() + .globals + .add_signing_key(&origin, get_keys_response.server_key.deserialize().unwrap())? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); + + pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))? + .insert(origin.to_string(), result); + } + } + + Ok(()) + } + + /// Returns Ok if the acl allows the server + pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { + let acl_event = match services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomServerAcl, + "", + )? { + Some(acl) => acl, + None => return Ok(()), + }; + + let acl_event_content: RoomServerAclEventContent = + match serde_json::from_str(acl_event.content.get()) { + Ok(content) => content, + Err(_) => { + warn!("Invalid ACL event"); + return Ok(()); + } + }; + + if acl_event_content.is_allowed(server_name) { + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::Forbidden, + "Server was denied by ACL", + )) + } + } + + /// Search the DB for the signing keys of the given server, if we don't have them + /// fetch them from the server and save to our DB. + #[tracing::instrument(skip_all)] + pub async fn fetch_signing_keys( + &self, + origin: &ServerName, + signature_ids: Vec<String>, + ) -> Result<BTreeMap<String, Base64>> { + let contains_all_ids = + |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); + + let permit = services() + .globals + .servername_ratelimiter + .read() + .unwrap() + .get(origin) + .map(|s| Arc::clone(s).acquire_owned()); + + let permit = match permit { + Some(p) => p, + None => { + let mut write = services().globals.servername_ratelimiter.write().unwrap(); + let s = Arc::clone( + write + .entry(origin.to_owned()) + .or_insert_with(|| Arc::new(Semaphore::new(1))), + ); + + s.acquire_owned() + } + } + .await; + + let back_off = |id| match services() + .globals + .bad_signature_ratelimiter + .write() + .unwrap() + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; + + if let Some((time, tries)) = services() + .globals + .bad_signature_ratelimiter + .read() + .unwrap() + .get(&signature_ids) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {:?}", signature_ids); + return Err(Error::BadServerResponse("bad signature, still backing off")); + } + } + + trace!("Loading signing keys for {}", origin); + + let mut result: BTreeMap<_, _> = services() + .globals + .signing_keys_for(origin)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); + + if contains_all_ids(&result) { + return Ok(result); + } + + debug!("Fetching signing keys for {} over federation", origin); + + if let Some(server_key) = services() + .sending + .send_federation_request(origin, get_server_keys::v2::Request::new()) + .await + .ok() + .and_then(|resp| resp.server_key.deserialize().ok()) + { + services() + .globals + .add_signing_key(origin, server_key.clone())?; + + result.extend( + server_key + .verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + server_key + .old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + + if contains_all_ids(&result) { + return Ok(result); + } + } + + for server in services().globals.trusted_servers() { + debug!("Asking {} for {}'s signing key", server, origin); + if let Some(server_keys) = services() + .sending + .send_federation_request( + server, + get_remote_server_keys::v2::Request::new( + origin.to_owned(), + MilliSecondsSinceUnixEpoch::from_system_time( + SystemTime::now() + .checked_add(Duration::from_secs(3600)) + .expect("SystemTime to large"), + ) + .expect("time is valid"), + ), + ) + .await + .ok() + .map(|resp| { + resp.server_keys + .into_iter() + .filter_map(|e| e.deserialize().ok()) + .collect::<Vec<_>>() + }) + { + trace!("Got signing keys: {:?}", server_keys); + for k in server_keys { + services().globals.add_signing_key(origin, k.clone())?; + result.extend( + k.verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + k.old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + } + + if contains_all_ids(&result) { + return Ok(result); + } + } + } + + drop(permit); + + back_off(signature_ids); + + warn!("Failed to find public key for server: {}", origin); + Err(Error::BadServerResponse( + "Failed to find public key for server", + )) + } +} diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs new file mode 100644 index 0000000..9af8e21 --- /dev/null +++ b/src/service/rooms/lazy_loading/data.rs @@ -0,0 +1,27 @@ +use crate::Result; +use ruma::{DeviceId, RoomId, UserId}; + +pub trait Data: Send + Sync { + fn lazy_load_was_sent_before( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ll_user: &UserId, + ) -> Result<bool>; + + fn lazy_load_confirm_delivery( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + confirmed_user_ids: &mut dyn Iterator<Item = &UserId>, + ) -> Result<()>; + + fn lazy_load_reset( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ) -> Result<()>; +} diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs new file mode 100644 index 0000000..701a734 --- /dev/null +++ b/src/service/rooms/lazy_loading/mod.rs @@ -0,0 +1,88 @@ +mod data; +use std::{ + collections::{HashMap, HashSet}, + sync::Mutex, +}; + +pub use data::Data; +use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, + + pub lazy_load_waiting: + Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, u64), HashSet<OwnedUserId>>>, +} + +impl Service { + #[tracing::instrument(skip(self))] + pub fn lazy_load_was_sent_before( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ll_user: &UserId, + ) -> Result<bool> { + self.db + .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) + } + + #[tracing::instrument(skip(self))] + pub fn lazy_load_mark_sent( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + lazy_load: HashSet<OwnedUserId>, + count: u64, + ) { + self.lazy_load_waiting.lock().unwrap().insert( + ( + user_id.to_owned(), + device_id.to_owned(), + room_id.to_owned(), + count, + ), + lazy_load, + ); + } + + #[tracing::instrument(skip(self))] + pub fn lazy_load_confirm_delivery( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + since: u64, + ) -> Result<()> { + if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( + user_id.to_owned(), + device_id.to_owned(), + room_id.to_owned(), + since, + )) { + self.db.lazy_load_confirm_delivery( + user_id, + device_id, + room_id, + &mut user_ids.iter().map(|u| &**u), + )?; + } else { + // Ignore + } + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub fn lazy_load_reset( + &self, + user_id: &UserId, + device_id: &DeviceId, + room_id: &RoomId, + ) -> Result<()> { + self.db.lazy_load_reset(user_id, device_id, room_id) + } +} diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs new file mode 100644 index 0000000..339db57 --- /dev/null +++ b/src/service/rooms/metadata/data.rs @@ -0,0 +1,9 @@ +use crate::Result; +use ruma::{OwnedRoomId, RoomId}; + +pub trait Data: Send + Sync { + fn exists(&self, room_id: &RoomId) -> Result<bool>; + fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; + fn is_disabled(&self, room_id: &RoomId) -> Result<bool>; + fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; +} diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs new file mode 100644 index 0000000..d188469 --- /dev/null +++ b/src/service/rooms/metadata/mod.rs @@ -0,0 +1,30 @@ +mod data; + +pub use data::Data; +use ruma::{OwnedRoomId, RoomId}; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Checks if a room exists. + #[tracing::instrument(skip(self))] + pub fn exists(&self, room_id: &RoomId) -> Result<bool> { + self.db.exists(room_id) + } + + pub fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { + self.db.iter_ids() + } + + pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { + self.db.is_disabled(room_id) + } + + pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + self.db.disable_room(room_id, disabled) + } +} diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs new file mode 100644 index 0000000..8956e4d --- /dev/null +++ b/src/service/rooms/mod.rs @@ -0,0 +1,57 @@ +pub mod alias; +pub mod auth_chain; +pub mod directory; +pub mod edus; +pub mod event_handler; +pub mod lazy_loading; +pub mod metadata; +pub mod outlier; +pub mod pdu_metadata; +pub mod search; +pub mod short; +pub mod state; +pub mod state_accessor; +pub mod state_cache; +pub mod state_compressor; +pub mod timeline; +pub mod user; + +pub trait Data: + alias::Data + + auth_chain::Data + + directory::Data + + edus::Data + + lazy_loading::Data + + metadata::Data + + outlier::Data + + pdu_metadata::Data + + search::Data + + short::Data + + state::Data + + state_accessor::Data + + state_cache::Data + + state_compressor::Data + + timeline::Data + + user::Data +{ +} + +pub struct Service { + pub alias: alias::Service, + pub auth_chain: auth_chain::Service, + pub directory: directory::Service, + pub edus: edus::Service, + pub event_handler: event_handler::Service, + pub lazy_loading: lazy_loading::Service, + pub metadata: metadata::Service, + pub outlier: outlier::Service, + pub pdu_metadata: pdu_metadata::Service, + pub search: search::Service, + pub short: short::Service, + pub state: state::Service, + pub state_accessor: state_accessor::Service, + pub state_cache: state_cache::Service, + pub state_compressor: state_compressor::Service, + pub timeline: timeline::Service, + pub user: user::Service, +} diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs new file mode 100644 index 0000000..0ed521d --- /dev/null +++ b/src/service/rooms/outlier/data.rs @@ -0,0 +1,9 @@ +use ruma::{CanonicalJsonObject, EventId}; + +use crate::{PduEvent, Result}; + +pub trait Data: Send + Sync { + fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>; + fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>; + fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; +} diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs new file mode 100644 index 0000000..dae41e4 --- /dev/null +++ b/src/service/rooms/outlier/mod.rs @@ -0,0 +1,28 @@ +mod data; + +pub use data::Data; +use ruma::{CanonicalJsonObject, EventId}; + +use crate::{PduEvent, Result}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Returns the pdu from the outlier tree. + pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { + self.db.get_outlier_pdu_json(event_id) + } + + /// Returns the pdu from the outlier tree. + pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { + self.db.get_outlier_pdu(event_id) + } + + /// Append the PDU as an outlier. + #[tracing::instrument(skip(self, pdu))] + pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + self.db.add_pdu_outlier(event_id, pdu) + } +} diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs new file mode 100644 index 0000000..b157938 --- /dev/null +++ b/src/service/rooms/pdu_metadata/data.rs @@ -0,0 +1,11 @@ +use std::sync::Arc; + +use crate::Result; +use ruma::{EventId, RoomId}; + +pub trait Data: Send + Sync { + fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()>; + fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool>; + fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; + fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool>; +} diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs new file mode 100644 index 0000000..b816678 --- /dev/null +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -0,0 +1,33 @@ +mod data; +use std::sync::Arc; + +pub use data::Data; +use ruma::{EventId, RoomId}; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + #[tracing::instrument(skip(self, room_id, event_ids))] + pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { + self.db.mark_as_referenced(room_id, event_ids) + } + + #[tracing::instrument(skip(self))] + pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { + self.db.is_event_referenced(room_id, event_id) + } + + #[tracing::instrument(skip(self))] + pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { + self.db.mark_event_soft_failed(event_id) + } + + #[tracing::instrument(skip(self))] + pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { + self.db.is_event_soft_failed(event_id) + } +} diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs new file mode 100644 index 0000000..6eef38f --- /dev/null +++ b/src/service/rooms/search/data.rs @@ -0,0 +1,12 @@ +use crate::Result; +use ruma::RoomId; + +pub trait Data: Send + Sync { + fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; + + fn search_pdus<'a>( + &'a self, + room_id: &RoomId, + search_string: &str, + ) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>; +} diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs new file mode 100644 index 0000000..b6f35e7 --- /dev/null +++ b/src/service/rooms/search/mod.rs @@ -0,0 +1,26 @@ +mod data; + +pub use data::Data; + +use crate::Result; +use ruma::RoomId; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + #[tracing::instrument(skip(self))] + pub fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + self.db.index_pdu(shortroomid, pdu_id, message_body) + } + + #[tracing::instrument(skip(self))] + pub fn search_pdus<'a>( + &'a self, + room_id: &RoomId, + search_string: &str, + ) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> { + self.db.search_pdus(room_id, search_string) + } +} diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs new file mode 100644 index 0000000..652c525 --- /dev/null +++ b/src/service/rooms/short/data.rs @@ -0,0 +1,31 @@ +use std::sync::Arc; + +use crate::Result; +use ruma::{events::StateEventType, EventId, RoomId}; + +pub trait Data: Send + Sync { + fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>; + + fn get_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<u64>>; + + fn get_or_create_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result<u64>; + + fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>>; + + fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; + + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; + + fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>; + + fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64>; +} diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs new file mode 100644 index 0000000..45fadd7 --- /dev/null +++ b/src/service/rooms/short/mod.rs @@ -0,0 +1,54 @@ +mod data; +use std::sync::Arc; + +pub use data::Data; +use ruma::{events::StateEventType, EventId, RoomId}; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> { + self.db.get_or_create_shorteventid(event_id) + } + + pub fn get_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<u64>> { + self.db.get_shortstatekey(event_type, state_key) + } + + pub fn get_or_create_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result<u64> { + self.db.get_or_create_shortstatekey(event_type, state_key) + } + + pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { + self.db.get_eventid_from_short(shorteventid) + } + + pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + self.db.get_statekey_from_short(shortstatekey) + } + + /// Returns (shortstatehash, already_existed) + pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + self.db.get_or_create_shortstatehash(state_hash) + } + + pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.db.get_shortroomid(room_id) + } + + pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> { + self.db.get_or_create_shortroomid(room_id) + } +} diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs new file mode 100644 index 0000000..96116b0 --- /dev/null +++ b/src/service/rooms/state/data.rs @@ -0,0 +1,31 @@ +use crate::Result; +use ruma::{EventId, OwnedEventId, RoomId}; +use std::{collections::HashSet, sync::Arc}; +use tokio::sync::MutexGuard; + +pub trait Data: Send + Sync { + /// Returns the last state hash key added to the db for the given room. + fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>>; + + /// Set the state hash to a new version, but does not update state_cache. + fn set_room_state( + &self, + room_id: &RoomId, + new_shortstatehash: u64, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()>; + + /// Associates a state with an event. + fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()>; + + /// Returns all events we would send as the prev_events of the next event. + fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>>; + + /// Replace the forward extremities of the room. + fn set_forward_extremities( + &self, + room_id: &RoomId, + event_ids: Vec<OwnedEventId>, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()>; +} diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs new file mode 100644 index 0000000..3072b80 --- /dev/null +++ b/src/service/rooms/state/mod.rs @@ -0,0 +1,421 @@ +mod data; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +pub use data::Data; +use ruma::{ + events::{ + room::{create::RoomCreateEventContent, member::MembershipState}, + AnyStrippedStateEvent, RoomEventType, StateEventType, + }, + serde::Raw, + state_res::{self, StateMap}, + EventId, OwnedEventId, RoomId, RoomVersionId, UserId, +}; +use serde::Deserialize; +use tokio::sync::MutexGuard; +use tracing::warn; + +use crate::{services, utils::calculate_hash, Error, PduEvent, Result}; + +use super::state_compressor::CompressedStateEvent; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Set the room to the given statehash and update caches. + pub async fn force_state( + &self, + room_id: &RoomId, + shortstatehash: u64, + statediffnew: HashSet<CompressedStateEvent>, + _statediffremoved: HashSet<CompressedStateEvent>, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + for event_id in statediffnew.into_iter().filter_map(|new| { + services() + .rooms + .state_compressor + .parse_compressed_state_event(&new) + .ok() + .map(|(_, id)| id) + }) { + let pdu = match services().rooms.timeline.get_pdu_json(&event_id)? { + Some(pdu) => pdu, + None => continue, + }; + + if pdu.get("type").and_then(|val| val.as_str()) != Some("m.room.member") { + continue; + } + + let pdu: PduEvent = match serde_json::from_str( + &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), + ) { + Ok(pdu) => pdu, + Err(_) => continue, + }; + + #[derive(Deserialize)] + struct ExtractMembership { + membership: MembershipState, + } + + let membership = match serde_json::from_str::<ExtractMembership>(pdu.content.get()) { + Ok(e) => e.membership, + Err(_) => continue, + }; + + let state_key = match pdu.state_key { + Some(k) => k, + None => continue, + }; + + let user_id = match UserId::parse(state_key) { + Ok(id) => id, + Err(_) => continue, + }; + + services().rooms.state_cache.update_membership( + room_id, + &user_id, + membership, + &pdu.sender, + None, + false, + )?; + } + + services().rooms.state_cache.update_joined_count(room_id)?; + + self.db + .set_room_state(room_id, shortstatehash, state_lock)?; + + Ok(()) + } + + /// Generates a new StateHash and associates it with the incoming event. + /// + /// This adds all current state events (not including the incoming event) + /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. + #[tracing::instrument(skip(self, state_ids_compressed))] + pub fn set_event_state( + &self, + event_id: &EventId, + room_id: &RoomId, + state_ids_compressed: HashSet<CompressedStateEvent>, + ) -> Result<u64> { + let shorteventid = services() + .rooms + .short + .get_or_create_shorteventid(event_id)?; + + let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; + + let state_hash = calculate_hash( + &state_ids_compressed + .iter() + .map(|s| &s[..]) + .collect::<Vec<_>>(), + ); + + let (shortstatehash, already_existed) = services() + .rooms + .short + .get_or_create_shortstatehash(&state_hash)?; + + if !already_existed { + let states_parents = previous_shortstatehash.map_or_else( + || Ok(Vec::new()), + |p| { + services() + .rooms + .state_compressor + .load_shortstatehash_info(p) + }, + )?; + + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew: HashSet<_> = state_ids_compressed + .difference(&parent_stateinfo.1) + .copied() + .collect(); + + let statediffremoved: HashSet<_> = parent_stateinfo + .1 + .difference(&state_ids_compressed) + .copied() + .collect(); + + (statediffnew, statediffremoved) + } else { + (state_ids_compressed, HashSet::new()) + }; + services().rooms.state_compressor.save_state_from_diff( + shortstatehash, + statediffnew, + statediffremoved, + 1_000_000, // high number because no state will be based on this one + states_parents, + )?; + } + + self.db.set_event_state(shorteventid, shortstatehash)?; + + Ok(shortstatehash) + } + + /// Generates a new StateHash and associates it with the incoming event. + /// + /// This adds all current state events (not including the incoming event) + /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. + #[tracing::instrument(skip(self, new_pdu))] + pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { + let shorteventid = services() + .rooms + .short + .get_or_create_shorteventid(&new_pdu.event_id)?; + + let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; + + if let Some(p) = previous_shortstatehash { + self.db.set_event_state(shorteventid, p)?; + } + + if let Some(state_key) = &new_pdu.state_key { + let states_parents = previous_shortstatehash.map_or_else( + || Ok(Vec::new()), + |p| { + services() + .rooms + .state_compressor + .load_shortstatehash_info(p) + }, + )?; + + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + + let new = services() + .rooms + .state_compressor + .compress_state_event(shortstatekey, &new_pdu.event_id)?; + + let replaces = states_parents + .last() + .map(|info| { + info.1 + .iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + }) + .unwrap_or_default(); + + if Some(&new) == replaces { + return Ok(previous_shortstatehash.expect("must exist")); + } + + // TODO: statehash with deterministic inputs + let shortstatehash = services().globals.next_count()?; + + let mut statediffnew = HashSet::new(); + statediffnew.insert(new); + + let mut statediffremoved = HashSet::new(); + if let Some(replaces) = replaces { + statediffremoved.insert(*replaces); + } + + services().rooms.state_compressor.save_state_from_diff( + shortstatehash, + statediffnew, + statediffremoved, + 2, + states_parents, + )?; + + Ok(shortstatehash) + } else { + Ok(previous_shortstatehash.expect("first event in room must be a state event")) + } + } + + #[tracing::instrument(skip(self, invite_event))] + pub fn calculate_invite_state( + &self, + invite_event: &PduEvent, + ) -> Result<Vec<Raw<AnyStrippedStateEvent>>> { + let mut state = Vec::new(); + // Add recommended events + if let Some(e) = services().rooms.state_accessor.room_state_get( + &invite_event.room_id, + &StateEventType::RoomCreate, + "", + )? { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = services().rooms.state_accessor.room_state_get( + &invite_event.room_id, + &StateEventType::RoomJoinRules, + "", + )? { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = services().rooms.state_accessor.room_state_get( + &invite_event.room_id, + &StateEventType::RoomCanonicalAlias, + "", + )? { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = services().rooms.state_accessor.room_state_get( + &invite_event.room_id, + &StateEventType::RoomAvatar, + "", + )? { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = services().rooms.state_accessor.room_state_get( + &invite_event.room_id, + &StateEventType::RoomName, + "", + )? { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = services().rooms.state_accessor.room_state_get( + &invite_event.room_id, + &StateEventType::RoomMember, + invite_event.sender.as_str(), + )? { + state.push(e.to_stripped_state_event()); + } + + state.push(invite_event.to_stripped_state_event()); + Ok(state) + } + + /// Set the state hash to a new version, but does not update state_cache. + #[tracing::instrument(skip(self))] + pub fn set_room_state( + &self, + room_id: &RoomId, + shortstatehash: u64, + mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + self.db.set_room_state(room_id, shortstatehash, mutex_lock) + } + + /// Returns the room's version. + #[tracing::instrument(skip(self))] + pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { + let create_event = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomCreate, + "", + )?; + + let create_event_content: Option<RoomCreateEventContent> = create_event + .as_ref() + .map(|create_event| { + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::bad_database("Invalid create event in db.") + }) + }) + .transpose()?; + let room_version = create_event_content + .map(|create_event| create_event.room_version) + .ok_or(Error::BadDatabase("Invalid room version"))?; + Ok(room_version) + } + + pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.db.get_room_shortstatehash(room_id) + } + + pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { + self.db.get_forward_extremities(room_id) + } + + pub fn set_forward_extremities( + &self, + room_id: &RoomId, + event_ids: Vec<OwnedEventId>, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + self.db + .set_forward_extremities(room_id, event_ids, state_lock) + } + + /// This fetches auth events from the current state. + #[tracing::instrument(skip(self))] + pub fn get_auth_events( + &self, + room_id: &RoomId, + kind: &RoomEventType, + sender: &UserId, + state_key: Option<&str>, + content: &serde_json::value::RawValue, + ) -> Result<StateMap<Arc<PduEvent>>> { + let shortstatehash = if let Some(current_shortstatehash) = + services().rooms.state.get_room_shortstatehash(room_id)? + { + current_shortstatehash + } else { + return Ok(HashMap::new()); + }; + + let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content) + .expect("content is a valid JSON object"); + + let mut sauthevents = auth_events + .into_iter() + .filter_map(|(event_type, state_key)| { + services() + .rooms + .short + .get_shortstatekey(&event_type.to_string().into(), &state_key) + .ok() + .flatten() + .map(|s| (s, (event_type, state_key))) + }) + .collect::<HashMap<_, _>>(); + + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + + Ok(full_state + .into_iter() + .filter_map(|compressed| { + services() + .rooms + .state_compressor + .parse_compressed_state_event(&compressed) + .ok() + }) + .filter_map(|(shortstatekey, event_id)| { + sauthevents.remove(&shortstatekey).map(|k| (k, event_id)) + }) + .filter_map(|(k, event_id)| { + services() + .rooms + .timeline + .get_pdu(&event_id) + .ok() + .flatten() + .map(|pdu| (k, pdu)) + }) + .collect()) + } +} diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs new file mode 100644 index 0000000..f3ae3c2 --- /dev/null +++ b/src/service/rooms/state_accessor/data.rs @@ -0,0 +1,59 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_trait::async_trait; +use ruma::{events::StateEventType, EventId, RoomId}; + +use crate::{PduEvent, Result}; + +#[async_trait] +pub trait Data: Send + Sync { + /// Builds a StateMap by iterating over all keys that start + /// with state_hash, this gives the full state for the given state_hash. + async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>>; + + async fn state_full( + &self, + shortstatehash: u64, + ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>>; + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn state_get_id( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<EventId>>>; + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn state_get( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<PduEvent>>>; + + /// Returns the state hash for this pdu. + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>>; + + /// Returns the full room state. + async fn room_state_full( + &self, + room_id: &RoomId, + ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>>; + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<EventId>>>; + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + fn room_state_get( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<PduEvent>>>; +} diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs new file mode 100644 index 0000000..87d9936 --- /dev/null +++ b/src/service/rooms/state_accessor/mod.rs @@ -0,0 +1,84 @@ +mod data; +use std::{collections::HashMap, sync::Arc}; + +pub use data::Data; +use ruma::{events::StateEventType, EventId, RoomId}; + +use crate::{PduEvent, Result}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Builds a StateMap by iterating over all keys that start + /// with state_hash, this gives the full state for the given state_hash. + #[tracing::instrument(skip(self))] + pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { + self.db.state_full_ids(shortstatehash).await + } + + pub async fn state_full( + &self, + shortstatehash: u64, + ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { + self.db.state_full(shortstatehash).await + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn state_get_id( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<EventId>>> { + self.db.state_get_id(shortstatehash, event_type, state_key) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + pub fn state_get( + &self, + shortstatehash: u64, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<PduEvent>>> { + self.db.state_get(shortstatehash, event_type, state_key) + } + + /// Returns the state hash for this pdu. + pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { + self.db.pdu_shortstatehash(event_id) + } + + /// Returns the full room state. + #[tracing::instrument(skip(self))] + pub async fn room_state_full( + &self, + room_id: &RoomId, + ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { + self.db.room_state_full(room_id).await + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<EventId>>> { + self.db.room_state_get_id(room_id, event_type, state_key) + } + + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<Arc<PduEvent>>> { + self.db.room_state_get(room_id, event_type, state_key) + } +} diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs new file mode 100644 index 0000000..d8bb4a4 --- /dev/null +++ b/src/service/rooms/state_cache/data.rs @@ -0,0 +1,111 @@ +use std::{collections::HashSet, sync::Arc}; + +use crate::Result; +use ruma::{ + events::{AnyStrippedStateEvent, AnySyncStateEvent}, + serde::Raw, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, +}; + +pub trait Data: Send + Sync { + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_invited( + &self, + user_id: &UserId, + room_id: &RoomId, + last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, + ) -> Result<()>; + fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + + fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; + + fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>>; + + fn appservice_in_room( + &self, + room_id: &RoomId, + appservice: &(String, serde_yaml::Value), + ) -> Result<bool>; + + /// Makes a user forget a room. + fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()>; + + /// Returns an iterator of all servers participating in this room. + fn room_servers<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a>; + + fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool>; + + /// Returns an iterator of all rooms a server participates in (as far as we know). + fn server_rooms<'a>( + &'a self, + server: &ServerName, + ) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; + + /// Returns an iterator over all joined members of a room. + fn room_members<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>; + + fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>>; + + fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>>; + + /// Returns an iterator over all User IDs who ever joined a room. + fn room_useroncejoined<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>; + + /// Returns an iterator over all invited members of a room. + fn room_members_invited<'a>( + &'a self, + room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>; + + fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>; + + fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>; + + /// Returns an iterator over all rooms this user joined. + fn rooms_joined<'a>( + &'a self, + user_id: &UserId, + ) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; + + /// Returns an iterator over all rooms a user was invited to. + fn rooms_invited<'a>( + &'a self, + user_id: &UserId, + ) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>; + + fn invite_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>>; + + fn left_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>>; + + /// Returns an iterator over all rooms a user left. + fn rooms_left<'a>( + &'a self, + user_id: &UserId, + ) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>; + + fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>; + + fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>; + + fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>; + + fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>; +} diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs new file mode 100644 index 0000000..32afdd4 --- /dev/null +++ b/src/service/rooms/state_cache/mod.rs @@ -0,0 +1,348 @@ +mod data; +use std::{collections::HashSet, sync::Arc}; + +pub use data::Data; + +use ruma::{ + events::{ + direct::DirectEvent, + ignored_user_list::IgnoredUserListEvent, + room::{create::RoomCreateEventContent, member::MembershipState}, + AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, + RoomAccountDataEventType, StateEventType, + }, + serde::Raw, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, +}; + +use crate::{services, Error, Result}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Update current membership data. + #[tracing::instrument(skip(self, last_state))] + pub fn update_membership( + &self, + room_id: &RoomId, + user_id: &UserId, + membership: MembershipState, + sender: &UserId, + last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, + update_joined_count: bool, + ) -> Result<()> { + // Keep track what remote users exist by adding them as "deactivated" users + if user_id.server_name() != services().globals.server_name() { + services().users.create(user_id, None)?; + // TODO: displayname, avatar url + } + + match &membership { + MembershipState::Join => { + // Check if the user never joined this room + if !self.once_joined(user_id, room_id)? { + // Add the user ID to the join list then + self.db.mark_as_once_joined(user_id, room_id)?; + + // Check if the room has a predecessor + if let Some(predecessor) = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "")? + .and_then(|create| serde_json::from_str(create.content.get()).ok()) + .and_then(|content: RoomCreateEventContent| content.predecessor) + { + // Copy user settings from predecessor to the current room: + // - Push rules + // + // TODO: finish this once push rules are implemented. + // + // let mut push_rules_event_content: PushRulesEvent = account_data + // .get( + // None, + // user_id, + // EventType::PushRules, + // )?; + // + // NOTE: find where `predecessor.room_id` match + // and update to `room_id`. + // + // account_data + // .update( + // None, + // user_id, + // EventType::PushRules, + // &push_rules_event_content, + // globals, + // ) + // .ok(); + + // Copy old tags to new room + if let Some(tag_event) = services() + .account_data + .get( + Some(&predecessor.room_id), + user_id, + RoomAccountDataEventType::Tag, + )? + .map(|event| { + serde_json::from_str(event.get()).map_err(|_| { + Error::bad_database("Invalid account data event in db.") + }) + }) + { + services() + .account_data + .update( + Some(room_id), + user_id, + RoomAccountDataEventType::Tag, + &tag_event?, + ) + .ok(); + }; + + // Copy direct chat flag + if let Some(direct_event) = services() + .account_data + .get( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + )? + .map(|event| { + serde_json::from_str::<DirectEvent>(event.get()).map_err(|_| { + Error::bad_database("Invalid account data event in db.") + }) + }) + { + let mut direct_event = direct_event?; + let mut room_ids_updated = false; + + for room_ids in direct_event.content.0.values_mut() { + if room_ids.iter().any(|r| r == &predecessor.room_id) { + room_ids.push(room_id.to_owned()); + room_ids_updated = true; + } + } + + if room_ids_updated { + services().account_data.update( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + &serde_json::to_value(&direct_event) + .expect("to json always works"), + )?; + } + }; + } + } + + self.db.mark_as_joined(user_id, room_id)?; + } + MembershipState::Invite => { + // We want to know if the sender is ignored by the receiver + let is_ignored = services() + .account_data + .get( + None, // Ignored users are in global account data + user_id, // Receiver + GlobalAccountDataEventType::IgnoredUserList + .to_string() + .into(), + )? + .map(|event| { + serde_json::from_str::<IgnoredUserListEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db.")) + }) + .transpose()? + .map_or(false, |ignored| { + ignored + .content + .ignored_users + .iter() + .any(|(user, _details)| user == sender) + }); + + if is_ignored { + return Ok(()); + } + + self.db.mark_as_invited(user_id, room_id, last_state)?; + } + MembershipState::Leave | MembershipState::Ban => { + self.db.mark_as_left(user_id, room_id)?; + } + _ => {} + } + + if update_joined_count { + self.update_joined_count(room_id)?; + } + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id))] + pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { + self.db.update_joined_count(room_id) + } + + #[tracing::instrument(skip(self, room_id))] + pub fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> { + self.db.get_our_real_users(room_id) + } + + #[tracing::instrument(skip(self, room_id, appservice))] + pub fn appservice_in_room( + &self, + room_id: &RoomId, + appservice: &(String, serde_yaml::Value), + ) -> Result<bool> { + self.db.appservice_in_room(room_id, appservice) + } + + /// Makes a user forget a room. + #[tracing::instrument(skip(self))] + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + self.db.forget(room_id, user_id) + } + + /// Returns an iterator of all servers participating in this room. + #[tracing::instrument(skip(self))] + pub fn room_servers<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<OwnedServerName>> + 'a { + self.db.room_servers(room_id) + } + + #[tracing::instrument(skip(self))] + pub fn server_in_room<'a>(&'a self, server: &ServerName, room_id: &RoomId) -> Result<bool> { + self.db.server_in_room(server, room_id) + } + + /// Returns an iterator of all rooms a server participates in (as far as we know). + #[tracing::instrument(skip(self))] + pub fn server_rooms<'a>( + &'a self, + server: &ServerName, + ) -> impl Iterator<Item = Result<OwnedRoomId>> + 'a { + self.db.server_rooms(server) + } + + /// Returns an iterator over all joined members of a room. + #[tracing::instrument(skip(self))] + pub fn room_members<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<OwnedUserId>> + 'a { + self.db.room_members(room_id) + } + + #[tracing::instrument(skip(self))] + pub fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.db.room_joined_count(room_id) + } + + #[tracing::instrument(skip(self))] + pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.db.room_invited_count(room_id) + } + + /// Returns an iterator over all User IDs who ever joined a room. + #[tracing::instrument(skip(self))] + pub fn room_useroncejoined<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<OwnedUserId>> + 'a { + self.db.room_useroncejoined(room_id) + } + + /// Returns an iterator over all invited members of a room. + #[tracing::instrument(skip(self))] + pub fn room_members_invited<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<OwnedUserId>> + 'a { + self.db.room_members_invited(room_id) + } + + #[tracing::instrument(skip(self))] + pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { + self.db.get_invite_count(room_id, user_id) + } + + #[tracing::instrument(skip(self))] + pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { + self.db.get_left_count(room_id, user_id) + } + + /// Returns an iterator over all rooms this user joined. + #[tracing::instrument(skip(self))] + pub fn rooms_joined<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator<Item = Result<OwnedRoomId>> + 'a { + self.db.rooms_joined(user_id) + } + + /// Returns an iterator over all rooms a user was invited to. + #[tracing::instrument(skip(self))] + pub fn rooms_invited<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a { + self.db.rooms_invited(user_id) + } + + #[tracing::instrument(skip(self))] + pub fn invite_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { + self.db.invite_state(user_id, room_id) + } + + #[tracing::instrument(skip(self))] + pub fn left_state( + &self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { + self.db.left_state(user_id, room_id) + } + + /// Returns an iterator over all rooms a user left. + #[tracing::instrument(skip(self))] + pub fn rooms_left<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a { + self.db.rooms_left(user_id) + } + + #[tracing::instrument(skip(self))] + pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + self.db.once_joined(user_id, room_id) + } + + #[tracing::instrument(skip(self))] + pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + self.db.is_joined(user_id, room_id) + } + + #[tracing::instrument(skip(self))] + pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + self.db.is_invited(user_id, room_id) + } + + #[tracing::instrument(skip(self))] + pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + self.db.is_left(user_id, room_id) + } +} diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs new file mode 100644 index 0000000..ce164c6 --- /dev/null +++ b/src/service/rooms/state_compressor/data.rs @@ -0,0 +1,15 @@ +use std::collections::HashSet; + +use super::CompressedStateEvent; +use crate::Result; + +pub struct StateDiff { + pub parent: Option<u64>, + pub added: HashSet<CompressedStateEvent>, + pub removed: HashSet<CompressedStateEvent>, +} + +pub trait Data: Send + Sync { + fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff>; + fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; +} diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs new file mode 100644 index 0000000..356f32c --- /dev/null +++ b/src/service/rooms/state_compressor/mod.rs @@ -0,0 +1,310 @@ +pub mod data; +use std::{ + collections::HashSet, + mem::size_of, + sync::{Arc, Mutex}, +}; + +pub use data::Data; +use lru_cache::LruCache; +use ruma::{EventId, RoomId}; + +use crate::{services, utils, Result}; + +use self::data::StateDiff; + +pub struct Service { + pub db: &'static dyn Data, + + pub stateinfo_cache: Mutex< + LruCache< + u64, + Vec<( + u64, // sstatehash + HashSet<CompressedStateEvent>, // full state + HashSet<CompressedStateEvent>, // added + HashSet<CompressedStateEvent>, // removed + )>, + >, + >, +} + +pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()]; + +impl Service { + /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. + #[tracing::instrument(skip(self))] + pub fn load_shortstatehash_info( + &self, + shortstatehash: u64, + ) -> Result< + Vec<( + u64, // sstatehash + HashSet<CompressedStateEvent>, // full state + HashSet<CompressedStateEvent>, // added + HashSet<CompressedStateEvent>, // removed + )>, + > { + if let Some(r) = self + .stateinfo_cache + .lock() + .unwrap() + .get_mut(&shortstatehash) + { + return Ok(r.clone()); + } + + let StateDiff { + parent, + added, + removed, + } = self.db.get_statediff(shortstatehash)?; + + if let Some(parent) = parent { + let mut response = self.load_shortstatehash_info(parent)?; + let mut state = response.last().unwrap().1.clone(); + state.extend(added.iter().copied()); + for r in &removed { + state.remove(r); + } + + response.push((shortstatehash, state, added, removed)); + + Ok(response) + } else { + let response = vec![(shortstatehash, added.clone(), added, removed)]; + self.stateinfo_cache + .lock() + .unwrap() + .insert(shortstatehash, response.clone()); + Ok(response) + } + } + + pub fn compress_state_event( + &self, + shortstatekey: u64, + event_id: &EventId, + ) -> Result<CompressedStateEvent> { + let mut v = shortstatekey.to_be_bytes().to_vec(); + v.extend_from_slice( + &services() + .rooms + .short + .get_or_create_shorteventid(event_id)? + .to_be_bytes(), + ); + Ok(v.try_into().expect("we checked the size above")) + } + + /// Returns shortstatekey, event id + pub fn parse_compressed_state_event( + &self, + compressed_event: &CompressedStateEvent, + ) -> Result<(u64, Arc<EventId>)> { + Ok(( + utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]) + .expect("bytes have right length"), + services().rooms.short.get_eventid_from_short( + utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]) + .expect("bytes have right length"), + )?, + )) + } + + /// Creates a new shortstatehash that often is just a diff to an already existing + /// shortstatehash and therefore very efficient. + /// + /// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer + /// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0 + /// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's + /// based on layer n-2. If that layer is also too big, it will recursively fix above layers too. + /// + /// * `shortstatehash` - Shortstatehash of this state + /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid + /// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid + /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer + /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer + #[tracing::instrument(skip( + self, + statediffnew, + statediffremoved, + diff_to_sibling, + parent_states + ))] + pub fn save_state_from_diff( + &self, + shortstatehash: u64, + statediffnew: HashSet<CompressedStateEvent>, + statediffremoved: HashSet<CompressedStateEvent>, + diff_to_sibling: usize, + mut parent_states: Vec<( + u64, // sstatehash + HashSet<CompressedStateEvent>, // full state + HashSet<CompressedStateEvent>, // added + HashSet<CompressedStateEvent>, // removed + )>, + ) -> Result<()> { + let diffsum = statediffnew.len() + statediffremoved.len(); + + if parent_states.len() > 3 { + // Number of layers + // To many layers, we have to go deeper + let parent = parent_states.pop().unwrap(); + + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + // It was not added in the parent and we removed it + parent_removed.insert(removed); + } + // Else it was added in the parent and we removed it again. We can forget this change + } + + for new in statediffnew { + if !parent_removed.remove(&new) { + // It was not touched in the parent and we added it + parent_new.insert(new); + } + // Else it was removed in the parent and we added it again. We can forget this change + } + + self.save_state_from_diff( + shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + )?; + + return Ok(()); + } + + if parent_states.is_empty() { + // There is no parent layer, create a new state + self.db.save_statediff( + shortstatehash, + StateDiff { + parent: None, + added: statediffnew, + removed: statediffremoved, + }, + )?; + + return Ok(()); + }; + + // Else we have two options. + // 1. We add the current diff on top of the parent layer. + // 2. We replace a layer above + + let parent = parent_states.pop().unwrap(); + let parent_diff = parent.2.len() + parent.3.len(); + + if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + // Diff too big, we replace above layer(s) + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + // It was not added in the parent and we removed it + parent_removed.insert(removed); + } + // Else it was added in the parent and we removed it again. We can forget this change + } + + for new in statediffnew { + if !parent_removed.remove(&new) { + // It was not touched in the parent and we added it + parent_new.insert(new); + } + // Else it was removed in the parent and we added it again. We can forget this change + } + + self.save_state_from_diff( + shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + )?; + } else { + // Diff small enough, we add diff as layer on top of parent + self.db.save_statediff( + shortstatehash, + StateDiff { + parent: Some(parent.0), + added: statediffnew, + removed: statediffremoved, + }, + )?; + } + + Ok(()) + } + + /// Returns the new shortstatehash, and the state diff from the previous room state + pub fn save_state( + &self, + room_id: &RoomId, + new_state_ids_compressed: HashSet<CompressedStateEvent>, + ) -> Result<( + u64, + HashSet<CompressedStateEvent>, + HashSet<CompressedStateEvent>, + )> { + let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; + + let state_hash = utils::calculate_hash( + &new_state_ids_compressed + .iter() + .map(|bytes| &bytes[..]) + .collect::<Vec<_>>(), + ); + + let (new_shortstatehash, already_existed) = services() + .rooms + .short + .get_or_create_shortstatehash(&state_hash)?; + + if Some(new_shortstatehash) == previous_shortstatehash { + return Ok((new_shortstatehash, HashSet::new(), HashSet::new())); + } + + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() + { + let statediffnew: HashSet<_> = new_state_ids_compressed + .difference(&parent_stateinfo.1) + .copied() + .collect(); + + let statediffremoved: HashSet<_> = parent_stateinfo + .1 + .difference(&new_state_ids_compressed) + .copied() + .collect(); + + (statediffnew, statediffremoved) + } else { + (new_state_ids_compressed, HashSet::new()) + }; + + if !already_existed { + self.save_state_from_diff( + new_shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 2, // every state change is 2 event changes on average + states_parents, + )?; + }; + + Ok((new_shortstatehash, statediffnew, statediffremoved)) + } +} diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs new file mode 100644 index 0000000..9377af0 --- /dev/null +++ b/src/service/rooms/timeline/data.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; + +use crate::{PduEvent, Result}; + +pub trait Data: Send + Sync { + fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>>; + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64>; + + /// Returns the `count` of this pdu's id. + fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>>; + + /// Returns the json of a pdu. + fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>; + + /// Returns the json of a pdu. + fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>; + + /// Returns the pdu's id. + fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>>; + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>; + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>>; + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>>; + + /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. + fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>>; + + /// Returns the `count` of this pdu's id. + fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64>; + + /// Adds a new pdu to the timeline + fn append_pdu( + &self, + pdu_id: &[u8], + pdu: &PduEvent, + json: &CanonicalJsonObject, + count: u64, + ) -> Result<()>; + + /// Removes a pdu and creates a new one with the same id. + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()>; + + /// Returns an iterator over all events in a room that happened after the event with id `since` + /// in chronological order. + fn pdus_since<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + since: u64, + ) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>>; + + /// Returns an iterator over all events and their tokens in a room that happened before the + /// event with id `until` in reverse-chronological order. + fn pdus_until<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + until: u64, + ) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>>; + + fn pdus_after<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + from: u64, + ) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>>; + + fn increment_notification_counts( + &self, + room_id: &RoomId, + notifies: Vec<OwnedUserId>, + highlights: Vec<OwnedUserId>, + ) -> Result<()>; +} diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs new file mode 100644 index 0000000..34399d4 --- /dev/null +++ b/src/service/rooms/timeline/mod.rs @@ -0,0 +1,831 @@ +mod data; + +use std::collections::HashMap; + +use std::{ + collections::HashSet, + sync::{Arc, Mutex}, +}; + +pub use data::Data; +use regex::Regex; +use ruma::{ + api::client::error::ErrorKind, + canonical_json::to_canonical_value, + events::{ + push_rules::PushRulesEvent, + room::{ + create::RoomCreateEventContent, member::MembershipState, + power_levels::RoomPowerLevelsEventContent, + }, + GlobalAccountDataEventType, RoomEventType, StateEventType, + }, + push::{Action, Ruleset, Tweak}, + state_res, + state_res::RoomVersion, + uint, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, + OwnedServerName, RoomAliasId, RoomId, UserId, +}; +use serde::Deserialize; +use serde_json::value::to_raw_value; +use tokio::sync::MutexGuard; +use tracing::{error, warn}; + +use crate::{ + service::pdu::{EventHash, PduBuilder}, + services, utils, Error, PduEvent, Result, +}; + +use super::state_compressor::CompressedStateEvent; + +pub struct Service { + pub db: &'static dyn Data, + + pub lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, u64>>, +} + +impl Service { + #[tracing::instrument(skip(self))] + pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { + self.db.first_pdu_in_room(room_id) + } + + #[tracing::instrument(skip(self))] + pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> { + self.db.last_timeline_count(sender_user, room_id) + } + + // TODO Is this the same as the function above? + /* + #[tracing::instrument(skip(self))] + pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.pduid_pdu + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|b| self.pdu_count(&b.0)) + .transpose() + .map(|op| op.unwrap_or_default()) + } + */ + + /// Returns the `count` of this pdu's id. + pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> { + self.db.get_pdu_count(event_id) + } + + /// Returns the json of a pdu. + pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { + self.db.get_pdu_json(event_id) + } + + /// Returns the json of a pdu. + pub fn get_non_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result<Option<CanonicalJsonObject>> { + self.db.get_non_outlier_pdu_json(event_id) + } + + /// Returns the pdu's id. + pub fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { + self.db.get_pdu_id(event_id) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { + self.db.get_non_outlier_pdu(event_id) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { + self.db.get_pdu(event_id) + } + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> { + self.db.get_pdu_from_id(pdu_id) + } + + /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { + self.db.get_pdu_json_from_id(pdu_id) + } + + /// Returns the `count` of this pdu's id. + pub fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> { + self.db.pdu_count(pdu_id) + } + + /// Removes a pdu and creates a new one with the same id. + #[tracing::instrument(skip(self))] + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { + self.db.replace_pdu(pdu_id, pdu) + } + + /// Creates a new persisted data unit and adds it to a room. + /// + /// By this point the incoming event should be fully authenticated, no auth happens + /// in `append_pdu`. + /// + /// Returns pdu id + #[tracing::instrument(skip(self, pdu, pdu_json, leaves))] + pub fn append_pdu<'a>( + &self, + pdu: &PduEvent, + mut pdu_json: CanonicalJsonObject, + leaves: Vec<OwnedEventId>, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<Vec<u8>> { + let shortroomid = services() + .rooms + .short + .get_shortroomid(&pdu.room_id)? + .expect("room exists"); + + // Make unsigned fields correct. This is not properly documented in the spec, but state + // events need to have previous content in the unsigned field, so clients can easily + // interpret things like membership changes + if let Some(state_key) = &pdu.state_key { + if let CanonicalJsonValue::Object(unsigned) = pdu_json + .entry("unsigned".to_owned()) + .or_insert_with(|| CanonicalJsonValue::Object(Default::default())) + { + if let Some(shortstatehash) = services() + .rooms + .state_accessor + .pdu_shortstatehash(&pdu.event_id) + .unwrap() + { + if let Some(prev_state) = services() + .rooms + .state_accessor + .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) + .unwrap() + { + unsigned.insert( + "prev_content".to_owned(), + CanonicalJsonValue::Object( + utils::to_canonical_object(prev_state.content.clone()) + .expect("event is valid, we just created it"), + ), + ); + } + } + } else { + error!("Invalid unsigned type in pdu."); + } + } + + // We must keep track of all events that have been referenced. + services() + .rooms + .pdu_metadata + .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + services() + .rooms + .state + .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; + + let mutex_insert = Arc::clone( + services() + .globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(pdu.room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + + let count1 = services().globals.next_count()?; + // Mark as read first so the sending client doesn't get a notification even if appending + // fails + services() + .rooms + .edus + .read_receipt + .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + services() + .rooms + .user + .reset_notification_counts(&pdu.sender, &pdu.room_id)?; + + let count2 = services().globals.next_count()?; + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); + pdu_id.extend_from_slice(&count2.to_be_bytes()); + + // Insert pdu + self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?; + + drop(insert_lock); + + // See if the event matches any known pushers + let power_levels: RoomPowerLevelsEventContent = services() + .rooms + .state_accessor + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + let sync_pdu = pdu.to_sync_room_event(); + + let mut notifies = Vec::new(); + let mut highlights = Vec::new(); + + for user in services() + .rooms + .state_cache + .get_our_real_users(&pdu.room_id)? + .iter() + { + // Don't notify the user of their own events + if user == &pdu.sender { + continue; + } + + let rules_for_user = services() + .account_data + .get( + None, + user, + GlobalAccountDataEventType::PushRules.to_string().into(), + )? + .map(|event| { + serde_json::from_str::<PushRulesEvent>(event.get()) + .map_err(|_| Error::bad_database("Invalid push rules event in db.")) + }) + .transpose()? + .map(|ev: PushRulesEvent| ev.content.global) + .unwrap_or_else(|| Ruleset::server_default(user)); + + let mut highlight = false; + let mut notify = false; + + for action in services().pusher.get_actions( + user, + &rules_for_user, + &power_levels, + &sync_pdu, + &pdu.room_id, + )? { + match action { + Action::DontNotify => notify = false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => notify = true, + Action::SetTweak(Tweak::Highlight(true)) => { + highlight = true; + } + _ => {} + }; + } + + if notify { + notifies.push(user.clone()); + } + + if highlight { + highlights.push(user.clone()); + } + + for push_key in services().pusher.get_pushkeys(user) { + services().sending.send_push_pdu(&pdu_id, user, push_key?)?; + } + } + + self.db + .increment_notification_counts(&pdu.room_id, notifies, highlights)?; + + match pdu.kind { + RoomEventType::RoomRedaction => { + if let Some(redact_id) = &pdu.redacts { + self.redact_pdu(redact_id, pdu)?; + } + } + RoomEventType::RoomMember => { + if let Some(state_key) = &pdu.state_key { + #[derive(Deserialize)] + struct ExtractMembership { + membership: MembershipState, + } + + // if the state_key fails + let target_user_id = UserId::parse(state_key.clone()) + .expect("This state_key was previously validated"); + + let content = serde_json::from_str::<ExtractMembership>(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + let invite_state = match content.membership { + MembershipState::Invite => { + let state = services().rooms.state.calculate_invite_state(pdu)?; + Some(state) + } + _ => None, + }; + + // Update our membership info, we do this here incase a user is invited + // and immediately leaves we need the DB to record the invite event for auth + services().rooms.state_cache.update_membership( + &pdu.room_id, + &target_user_id, + content.membership, + &pdu.sender, + invite_state, + true, + )?; + } + } + RoomEventType::RoomMessage => { + #[derive(Deserialize)] + struct ExtractBody { + body: Option<String>, + } + + let content = serde_json::from_str::<ExtractBody>(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + if let Some(body) = content.body { + services() + .rooms + .search + .index_pdu(shortroomid, &pdu_id, &body)?; + + let admin_room = services().rooms.alias.resolve_local_alias( + <&RoomAliasId>::try_from( + format!("#admins:{}", services().globals.server_name()).as_str(), + ) + .expect("#admins:server_name is a valid room alias"), + )?; + let server_user = format!("@conduit:{}", services().globals.server_name()); + + let to_conduit = body.starts_with(&format!("{server_user}: ")); + + // This will evaluate to false if the emergency password is set up so that + // the administrator can execute commands as conduit + let from_conduit = pdu.sender == server_user + && services().globals.emergency_password().is_none(); + + if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { + services().admin.process_message(body); + } + } + } + _ => {} + } + + for appservice in services().appservice.all()? { + if services() + .rooms + .state_cache + .appservice_in_room(&pdu.room_id, &appservice)? + { + services() + .sending + .send_pdu_appservice(appservice.0, pdu_id.clone())?; + continue; + } + + // If the RoomMember event has a non-empty state_key, it is targeted at someone. + // If it is our appservice user, we send this PDU to it. + if pdu.kind == RoomEventType::RoomMember { + if let Some(state_key_uid) = &pdu + .state_key + .as_ref() + .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + { + if let Some(appservice_uid) = appservice + .1 + .get("sender_localpart") + .and_then(|string| string.as_str()) + .and_then(|string| { + UserId::parse_with_server_name(string, services().globals.server_name()) + .ok() + }) + { + if state_key_uid == &appservice_uid { + services() + .sending + .send_pdu_appservice(appservice.0, pdu_id.clone())?; + continue; + } + } + } + } + + if let Some(namespaces) = appservice.1.get("namespaces") { + let users = namespaces + .get("users") + .and_then(|users| users.as_sequence()) + .map_or_else(Vec::new, |users| { + users + .iter() + .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) + .collect::<Vec<_>>() + }); + let aliases = namespaces + .get("aliases") + .and_then(|aliases| aliases.as_sequence()) + .map_or_else(Vec::new, |aliases| { + aliases + .iter() + .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) + .collect::<Vec<_>>() + }); + let rooms = namespaces + .get("rooms") + .and_then(|rooms| rooms.as_sequence()); + + let matching_users = |users: &Regex| { + users.is_match(pdu.sender.as_str()) + || pdu.kind == RoomEventType::RoomMember + && pdu + .state_key + .as_ref() + .map_or(false, |state_key| users.is_match(state_key)) + }; + let matching_aliases = |aliases: &Regex| { + services() + .rooms + .alias + .local_aliases_for_room(&pdu.room_id) + .filter_map(|r| r.ok()) + .any(|room_alias| aliases.is_match(room_alias.as_str())) + }; + + if aliases.iter().any(matching_aliases) + || rooms.map_or(false, |rooms| rooms.contains(&pdu.room_id.as_str().into())) + || users.iter().any(matching_users) + { + services() + .sending + .send_pdu_appservice(appservice.0, pdu_id.clone())?; + } + } + } + + Ok(pdu_id) + } + + pub fn create_hash_and_sign_event( + &self, + pdu_builder: PduBuilder, + sender: &UserId, + room_id: &RoomId, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<(PduEvent, CanonicalJsonObject)> { + let PduBuilder { + event_type, + content, + unsigned, + state_key, + redacts, + } = pdu_builder; + + let prev_events: Vec<_> = services() + .rooms + .state + .get_forward_extremities(room_id)? + .into_iter() + .take(20) + .collect(); + + let create_event = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomCreate, + "", + )?; + + let create_event_content: Option<RoomCreateEventContent> = create_event + .as_ref() + .map(|create_event| { + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::bad_database("Invalid create event in db.") + }) + }) + .transpose()?; + + // If there was no create event yet, assume we are creating a room with the default + // version right now + let room_version_id = create_event_content + .map_or(services().globals.default_room_version(), |create_event| { + create_event.room_version + }); + let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); + + let auth_events = services().rooms.state.get_auth_events( + room_id, + &event_type, + sender, + state_key.as_deref(), + &content, + )?; + + // Our depth is the maximum depth of prev_events + 1 + let depth = prev_events + .iter() + .filter_map(|event_id| Some(services().rooms.timeline.get_pdu(event_id).ok()??.depth)) + .max() + .unwrap_or_else(|| uint!(0)) + + uint!(1); + + let mut unsigned = unsigned.unwrap_or_default(); + + if let Some(state_key) = &state_key { + if let Some(prev_pdu) = services().rooms.state_accessor.room_state_get( + room_id, + &event_type.to_string().into(), + state_key, + )? { + unsigned.insert( + "prev_content".to_owned(), + serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), + ); + unsigned.insert( + "prev_sender".to_owned(), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + ); + } + } + + let mut pdu = PduEvent { + event_id: ruma::event_id!("$thiswillbefilledinlater").into(), + room_id: room_id.to_owned(), + sender: sender.to_owned(), + origin_server_ts: utils::millis_since_unix_epoch() + .try_into() + .expect("time is valid"), + kind: event_type, + content, + state_key, + prev_events, + depth, + auth_events: auth_events + .values() + .map(|pdu| pdu.event_id.clone()) + .collect(), + redacts, + unsigned: if unsigned.is_empty() { + None + } else { + Some(to_raw_value(&unsigned).expect("to_raw_value always works")) + }, + hashes: EventHash { + sha256: "aaa".to_owned(), + }, + signatures: None, + }; + + let auth_check = state_res::auth_check( + &room_version, + &pdu, + None::<PduEvent>, // TODO: third_party_invite + |k, s| auth_events.get(&(k.clone(), s.to_owned())), + ) + .map_err(|e| { + error!("{:?}", e); + Error::bad_database("Auth check failed.") + })?; + + if !auth_check { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Event is not authorized.", + )); + } + + // Hash and sign + let mut pdu_json = + utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); + + pdu_json.remove("event_id"); + + // Add origin because synapse likes that (and it's required in the spec) + pdu_json.insert( + "origin".to_owned(), + to_canonical_value(services().globals.server_name()) + .expect("server name is a valid CanonicalJsonValue"), + ); + + match ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut pdu_json, + &room_version_id, + ) { + Ok(_) => {} + Err(e) => { + return match e { + ruma::signatures::Error::PduSize => Err(Error::BadRequest( + ErrorKind::TooLarge, + "Message is too long", + )), + _ => Err(Error::BadRequest( + ErrorKind::Unknown, + "Signing event failed", + )), + } + } + } + + // Generate event id + pdu.event_id = EventId::parse_arc(format!( + "${}", + ruma::signatures::reference_hash(&pdu_json, &room_version_id) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + pdu_json.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(pdu.event_id.as_str().to_owned()), + ); + + // Generate short event id + let _shorteventid = services() + .rooms + .short + .get_or_create_shorteventid(&pdu.event_id)?; + + Ok((pdu, pdu_json)) + } + + /// Creates a new persisted data unit and adds it to a room. This function takes a + /// roomid_mutex_state, meaning that only this function is able to mutate the room state. + #[tracing::instrument(skip(self, state_lock))] + pub fn build_and_append_pdu( + &self, + pdu_builder: PduBuilder, + sender: &UserId, + room_id: &RoomId, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<Arc<EventId>> { + let (pdu, pdu_json) = + self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; + + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + let statehashid = services().rooms.state.append_to_state(&pdu)?; + + let pdu_id = self.append_pdu( + &pdu, + pdu_json, + // Since this PDU references all pdu_leaves we can update the leaves + // of the room + vec![(*pdu.event_id).to_owned()], + state_lock, + )?; + + // We set the room state after inserting the pdu, so that we never have a moment in time + // where events in the current room state do not exist + services() + .rooms + .state + .set_room_state(room_id, statehashid, state_lock)?; + + let mut servers: HashSet<OwnedServerName> = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(|r| r.ok()) + .collect(); + + // In case we are kicking or banning a user, we need to inform their server of the change + if pdu.kind == RoomEventType::RoomMember { + if let Some(state_key_uid) = &pdu + .state_key + .as_ref() + .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + { + servers.insert(state_key_uid.server_name().to_owned()); + } + } + + // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above + servers.remove(services().globals.server_name()); + + services().sending.send_pdu(servers.into_iter(), &pdu_id)?; + + Ok(pdu.event_id) + } + + /// Append the incoming event setting the state snapshot to the state from the + /// server that sent the event. + #[tracing::instrument(skip_all)] + pub fn append_incoming_pdu<'a>( + &self, + pdu: &PduEvent, + pdu_json: CanonicalJsonObject, + new_room_leaves: Vec<OwnedEventId>, + state_ids_compressed: HashSet<CompressedStateEvent>, + soft_fail: bool, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<Option<Vec<u8>>> { + // We append to state before appending the pdu, so we don't have a moment in time with the + // pdu without it's state. This is okay because append_pdu can't fail. + services().rooms.state.set_event_state( + &pdu.event_id, + &pdu.room_id, + state_ids_compressed, + )?; + + if soft_fail { + services() + .rooms + .pdu_metadata + .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + services().rooms.state.set_forward_extremities( + &pdu.room_id, + new_room_leaves, + state_lock, + )?; + return Ok(None); + } + + let pdu_id = + services() + .rooms + .timeline + .append_pdu(pdu, pdu_json, new_room_leaves, state_lock)?; + + Ok(Some(pdu_id)) + } + + /// Returns an iterator over all PDUs in a room. + pub fn all_pdus<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { + self.pdus_since(user_id, room_id, 0) + } + + /// Returns an iterator over all events in a room that happened after the event with id `since` + /// in chronological order. + pub fn pdus_since<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + since: u64, + ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { + self.db.pdus_since(user_id, room_id, since) + } + + /// Returns an iterator over all events and their tokens in a room that happened before the + /// event with id `until` in reverse-chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_until<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + until: u64, + ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { + self.db.pdus_until(user_id, room_id, until) + } + + /// Returns an iterator over all events and their token in a room that happened after the event + /// with id `from` in chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_after<'a>( + &'a self, + user_id: &UserId, + room_id: &RoomId, + from: u64, + ) -> Result<impl Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a> { + self.db.pdus_after(user_id, room_id, from) + } + + /// Replace a PDU with the redacted form. + #[tracing::instrument(skip(self, reason))] + pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { + if let Some(pdu_id) = self.get_pdu_id(event_id)? { + let mut pdu = self + .get_pdu_from_id(&pdu_id)? + .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; + pdu.redact(reason)?; + self.replace_pdu(&pdu_id, &pdu)?; + } + // If event does not exist, just noop + Ok(()) + } +} diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs new file mode 100644 index 0000000..4b8a4ec --- /dev/null +++ b/src/service/rooms/user/data.rs @@ -0,0 +1,27 @@ +use crate::Result; +use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; + +pub trait Data: Send + Sync { + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>; + + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>; + + // Returns the count at which the last reset_notification_counts was called + fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64>; + + fn associate_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + shortstatehash: u64, + ) -> Result<()>; + + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>>; + + fn get_shared_rooms<'a>( + &'a self, + users: Vec<OwnedUserId>, + ) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>>; +} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs new file mode 100644 index 0000000..672e502 --- /dev/null +++ b/src/service/rooms/user/mod.rs @@ -0,0 +1,49 @@ +mod data; + +pub use data::Data; +use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; + +use crate::Result; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + self.db.reset_notification_counts(user_id, room_id) + } + + pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { + self.db.notification_count(user_id, room_id) + } + + pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { + self.db.highlight_count(user_id, room_id) + } + + pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { + self.db.last_notification_read(user_id, room_id) + } + + pub fn associate_token_shortstatehash( + &self, + room_id: &RoomId, + token: u64, + shortstatehash: u64, + ) -> Result<()> { + self.db + .associate_token_shortstatehash(room_id, token, shortstatehash) + } + + pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { + self.db.get_token_shortstatehash(room_id, token) + } + + pub fn get_shared_rooms( + &self, + users: Vec<OwnedUserId>, + ) -> Result<impl Iterator<Item = Result<OwnedRoomId>>> { + self.db.get_shared_rooms(users) + } +} diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs new file mode 100644 index 0000000..2e574e2 --- /dev/null +++ b/src/service/sending/data.rs @@ -0,0 +1,29 @@ +use ruma::ServerName; + +use crate::Result; + +use super::{OutgoingKind, SendingEventType}; + +pub trait Data: Send + Sync { + fn active_requests<'a>( + &'a self, + ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a>; + fn active_requests_for<'a>( + &'a self, + outgoing_kind: &OutgoingKind, + ) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a>; + fn delete_active_request(&self, key: Vec<u8>) -> Result<()>; + fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; + fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; + fn queue_requests( + &self, + requests: &[(&OutgoingKind, SendingEventType)], + ) -> Result<Vec<Vec<u8>>>; + fn queued_requests<'a>( + &'a self, + outgoing_kind: &OutgoingKind, + ) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a>; + fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()>; + fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>; + fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64>; +} diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs new file mode 100644 index 0000000..1861feb --- /dev/null +++ b/src/service/sending/mod.rs @@ -0,0 +1,708 @@ +mod data; + +pub use data::Data; + +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + fmt::Debug, + sync::Arc, + time::{Duration, Instant}, +}; + +use crate::{ + api::{appservice_server, server_server}, + services, + utils::calculate_hash, + Config, Error, PduEvent, Result, +}; +use federation::transactions::send_transaction_message; +use futures_util::{stream::FuturesUnordered, StreamExt}; + +use ruma::{ + api::{ + appservice, + federation::{ + self, + transactions::edu::{ + DeviceListUpdateContent, Edu, ReceiptContent, ReceiptData, ReceiptMap, + }, + }, + OutgoingRequest, + }, + device_id, + events::{ + push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, + GlobalAccountDataEventType, + }, + push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, ServerName, UInt, UserId, +}; +use tokio::{ + select, + sync::{mpsc, Mutex, Semaphore}, +}; +use tracing::{error, warn}; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum OutgoingKind { + Appservice(String), + Push(OwnedUserId, String), // user and pushkey + Normal(OwnedServerName), +} + +impl OutgoingKind { + #[tracing::instrument(skip(self))] + pub fn get_prefix(&self) -> Vec<u8> { + let mut prefix = match self { + OutgoingKind::Appservice(server) => { + let mut p = b"+".to_vec(); + p.extend_from_slice(server.as_bytes()); + p + } + OutgoingKind::Push(user, pushkey) => { + let mut p = b"$".to_vec(); + p.extend_from_slice(user.as_bytes()); + p.push(0xff); + p.extend_from_slice(pushkey.as_bytes()); + p + } + OutgoingKind::Normal(server) => { + let mut p = Vec::new(); + p.extend_from_slice(server.as_bytes()); + p + } + }; + prefix.push(0xff); + + prefix + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum SendingEventType { + Pdu(Vec<u8>), // pduid + Edu(Vec<u8>), // pdu json +} + +pub struct Service { + db: &'static dyn Data, + + /// The state for a given state hash. + pub(super) maximum_requests: Arc<Semaphore>, + pub sender: mpsc::UnboundedSender<(OutgoingKind, SendingEventType, Vec<u8>)>, + receiver: Mutex<mpsc::UnboundedReceiver<(OutgoingKind, SendingEventType, Vec<u8>)>>, +} + +enum TransactionStatus { + Running, + Failed(u32, Instant), // number of times failed, time of last failure + Retrying(u32), // number of times failed +} + +impl Service { + pub fn build(db: &'static dyn Data, config: &Config) -> Arc<Self> { + let (sender, receiver) = mpsc::unbounded_channel(); + Arc::new(Self { + db, + sender, + receiver: Mutex::new(receiver), + maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), + }) + } + + pub fn start_handler(self: &Arc<Self>) { + let self2 = Arc::clone(self); + tokio::spawn(async move { + self2.handler().await.unwrap(); + }); + } + + async fn handler(&self) -> Result<()> { + let mut receiver = self.receiver.lock().await; + + let mut futures = FuturesUnordered::new(); + + let mut current_transaction_status = HashMap::<OutgoingKind, TransactionStatus>::new(); + + // Retry requests we could not finish yet + let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new(); + + for (key, outgoing_kind, event) in self.db.active_requests().filter_map(|r| r.ok()) { + let entry = initial_transactions + .entry(outgoing_kind.clone()) + .or_insert_with(Vec::new); + + if entry.len() > 30 { + warn!( + "Dropping some current events: {:?} {:?} {:?}", + key, outgoing_kind, event + ); + self.db.delete_active_request(key)?; + continue; + } + + entry.push(event); + } + + for (outgoing_kind, events) in initial_transactions { + current_transaction_status.insert(outgoing_kind.clone(), TransactionStatus::Running); + futures.push(Self::handle_events(outgoing_kind.clone(), events)); + } + + loop { + select! { + Some(response) = futures.next() => { + match response { + Ok(outgoing_kind) => { + self.db.delete_all_active_requests_for(&outgoing_kind)?; + + // Find events that have been added since starting the last request + let new_events = self.db.queued_requests(&outgoing_kind).filter_map(|r| r.ok()).take(30).collect::<Vec<_>>(); + + if !new_events.is_empty() { + // Insert pdus we found + self.db.mark_as_active(&new_events)?; + + futures.push( + Self::handle_events( + outgoing_kind.clone(), + new_events.into_iter().map(|(event, _)| event).collect(), + ) + ); + } else { + current_transaction_status.remove(&outgoing_kind); + } + } + Err((outgoing_kind, _)) => { + current_transaction_status.entry(outgoing_kind).and_modify(|e| *e = match e { + TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), + TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n+1, Instant::now()), + TransactionStatus::Failed(_, _) => { + error!("Request that was not even running failed?!"); + return + }, + }); + } + }; + }, + Some((outgoing_kind, event, key)) = receiver.recv() => { + if let Ok(Some(events)) = self.select_events( + &outgoing_kind, + vec![(event, key)], + &mut current_transaction_status, + ) { + futures.push(Self::handle_events(outgoing_kind, events)); + } + } + } + } + } + + #[tracing::instrument(skip(self, outgoing_kind, new_events, current_transaction_status))] + fn select_events( + &self, + outgoing_kind: &OutgoingKind, + new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key + current_transaction_status: &mut HashMap<OutgoingKind, TransactionStatus>, + ) -> Result<Option<Vec<SendingEventType>>> { + let mut retry = false; + let mut allow = true; + + let entry = current_transaction_status.entry(outgoing_kind.clone()); + + entry + .and_modify(|e| match e { + TransactionStatus::Running | TransactionStatus::Retrying(_) => { + allow = false; // already running + } + TransactionStatus::Failed(tries, time) => { + // Fail if a request has failed recently (exponential backoff) + let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + allow = false; + } else { + retry = true; + *e = TransactionStatus::Retrying(*tries); + } + } + }) + .or_insert(TransactionStatus::Running); + + if !allow { + return Ok(None); + } + + let mut events = Vec::new(); + + if retry { + // We retry the previous transaction + for (_, e) in self + .db + .active_requests_for(outgoing_kind) + .filter_map(|r| r.ok()) + { + events.push(e); + } + } else { + self.db.mark_as_active(&new_events)?; + for (e, _) in new_events { + events.push(e); + } + + if let OutgoingKind::Normal(server_name) = outgoing_kind { + if let Ok((select_edus, last_count)) = self.select_edus(server_name) { + events.extend(select_edus.into_iter().map(SendingEventType::Edu)); + + self.db.set_latest_educount(server_name, last_count)?; + } + } + } + + Ok(Some(events)) + } + + #[tracing::instrument(skip(self, server_name))] + pub fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { + // u64: count of last edu + let since = self.db.get_latest_educount(server_name)?; + let mut events = Vec::new(); + let mut max_edu_count = since; + let mut device_list_changes = HashSet::new(); + + 'outer: for room_id in services().rooms.state_cache.server_rooms(server_name) { + let room_id = room_id?; + // Look for device list updates in this room + device_list_changes.extend( + services() + .users + .keys_changed(room_id.as_ref(), since, None) + .filter_map(|r| r.ok()) + .filter(|user_id| user_id.server_name() == services().globals.server_name()), + ); + + // Look for read receipts in this room + for r in services() + .rooms + .edus + .read_receipt + .readreceipts_since(&room_id, since) + { + let (user_id, count, read_receipt) = r?; + + if count > max_edu_count { + max_edu_count = count; + } + + if user_id.server_name() != services().globals.server_name() { + continue; + } + + let event: AnySyncEphemeralRoomEvent = + serde_json::from_str(read_receipt.json().get()) + .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; + let federation_event = match event { + AnySyncEphemeralRoomEvent::Receipt(r) => { + let mut read = BTreeMap::new(); + + let (event_id, mut receipt) = r + .content + .0 + .into_iter() + .next() + .expect("we only use one event per read receipt"); + let receipt = receipt + .remove(&ReceiptType::Read) + .expect("our read receipts always set this") + .remove(&user_id) + .expect("our read receipts always have the user here"); + + read.insert( + user_id, + ReceiptData { + data: receipt.clone(), + event_ids: vec![event_id.clone()], + }, + ); + + let receipt_map = ReceiptMap { read }; + + let mut receipts = BTreeMap::new(); + receipts.insert(room_id.clone(), receipt_map); + + Edu::Receipt(ReceiptContent { receipts }) + } + _ => { + Error::bad_database("Invalid event type in read_receipts"); + continue; + } + }; + + events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); + + if events.len() >= 20 { + break 'outer; + } + } + } + + for user_id in device_list_changes { + // Empty prev id forces synapse to resync: https://github.com/matrix-org/synapse/blob/98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca/synapse/handlers/device.py#L767 + // Because synapse resyncs, we can just insert dummy data + let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { + user_id, + device_id: device_id!("dummy").to_owned(), + device_display_name: Some("Dummy".to_owned()), + stream_id: uint!(1), + prev_id: Vec::new(), + deleted: None, + keys: None, + }); + + events.push(serde_json::to_vec(&edu).expect("json can be serialized")); + } + + Ok((events, max_edu_count)) + } + + #[tracing::instrument(skip(self, pdu_id, user, pushkey))] + pub fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); + let event = SendingEventType::Pdu(pdu_id.to_owned()); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender + .send((outgoing_kind, event, keys.into_iter().next().unwrap())) + .unwrap(); + + Ok(()) + } + + #[tracing::instrument(skip(self, servers, pdu_id))] + pub fn send_pdu<I: Iterator<Item = OwnedServerName>>( + &self, + servers: I, + pdu_id: &[u8], + ) -> Result<()> { + let requests = servers + .into_iter() + .map(|server| { + ( + OutgoingKind::Normal(server), + SendingEventType::Pdu(pdu_id.to_owned()), + ) + }) + .collect::<Vec<_>>(); + let keys = self.db.queue_requests( + &requests + .iter() + .map(|(o, e)| (o, e.clone())) + .collect::<Vec<_>>(), + )?; + for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { + self.sender + .send((outgoing_kind.to_owned(), event, key)) + .unwrap(); + } + + Ok(()) + } + + #[tracing::instrument(skip(self, server, serialized))] + pub fn send_reliable_edu( + &self, + server: &ServerName, + serialized: Vec<u8>, + id: u64, + ) -> Result<()> { + let outgoing_kind = OutgoingKind::Normal(server.to_owned()); + let event = SendingEventType::Edu(serialized); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender + .send((outgoing_kind, event, keys.into_iter().next().unwrap())) + .unwrap(); + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> { + let outgoing_kind = OutgoingKind::Appservice(appservice_id); + let event = SendingEventType::Pdu(pdu_id); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender + .send((outgoing_kind, event, keys.into_iter().next().unwrap())) + .unwrap(); + + Ok(()) + } + + /// Cleanup event data + /// Used for instance after we remove an appservice registration + /// + #[tracing::instrument(skip(self))] + pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { + self.db + .delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?; + + Ok(()) + } + + #[tracing::instrument(skip(events, kind))] + async fn handle_events( + kind: OutgoingKind, + events: Vec<SendingEventType>, + ) -> Result<OutgoingKind, (OutgoingKind, Error)> { + match &kind { + OutgoingKind::Appservice(id) => { + let mut pdu_jsons = Vec::new(); + + for event in &events { + match event { + SendingEventType::Pdu(pdu_id) => { + pdu_jsons.push(services().rooms.timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (kind.clone(), e))? + .ok_or_else(|| { + ( + kind.clone(), + Error::bad_database( + "[Appservice] Event in servernameevent_data not found in db.", + ), + ) + })? + .to_room_event()) + } + SendingEventType::Edu(_) => { + // Appservices don't need EDUs (?) + } + } + } + + let permit = services().sending.maximum_requests.acquire().await; + + let response = appservice_server::send_request( + services() + .appservice + .get_registration(id) + .map_err(|e| (kind.clone(), e))? + .ok_or_else(|| { + ( + kind.clone(), + Error::bad_database( + "[Appservice] Could not load registration from db.", + ), + ) + })?, + appservice::event::push_events::v1::Request { + events: pdu_jsons, + txn_id: (&*base64::encode_config( + calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, + }) + .collect::<Vec<_>>(), + ), + base64::URL_SAFE_NO_PAD, + )) + .into(), + }, + ) + .await + .map(|_response| kind.clone()) + .map_err(|e| (kind, e)); + + drop(permit); + + response + } + OutgoingKind::Push(userid, pushkey) => { + let mut pdus = Vec::new(); + + for event in &events { + match event { + SendingEventType::Pdu(pdu_id) => { + pdus.push( + services().rooms + .timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (kind.clone(), e))? + .ok_or_else(|| { + ( + kind.clone(), + Error::bad_database( + "[Push] Event in servernamevent_datas not found in db.", + ), + ) + })?, + ); + } + SendingEventType::Edu(_) => { + // Push gateways don't need EDUs (?) + } + } + } + + for pdu in pdus { + // Redacted events are not notification targets (we don't send push for them) + if let Some(unsigned) = &pdu.unsigned { + if let Ok(unsigned) = + serde_json::from_str::<serde_json::Value>(unsigned.get()) + { + if unsigned.get("redacted_because").is_some() { + continue; + } + } + } + + let pusher = match services() + .pusher + .get_pusher(userid, pushkey) + .map_err(|e| (OutgoingKind::Push(userid.clone(), pushkey.clone()), e))? + { + Some(pusher) => pusher, + None => continue, + }; + + let rules_for_user = services() + .account_data + .get( + None, + userid, + GlobalAccountDataEventType::PushRules.to_string().into(), + ) + .unwrap_or_default() + .and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok()) + .map(|ev: PushRulesEvent| ev.content.global) + .unwrap_or_else(|| push::Ruleset::server_default(userid)); + + let unread: UInt = services() + .rooms + .user + .notification_count(userid, &pdu.room_id) + .map_err(|e| (kind.clone(), e))? + .try_into() + .expect("notification count can't go that high"); + + let permit = services().sending.maximum_requests.acquire().await; + + let _response = services() + .pusher + .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) + .await + .map(|_response| kind.clone()) + .map_err(|e| (kind.clone(), e)); + + drop(permit); + } + Ok(OutgoingKind::Push(userid.clone(), pushkey.clone())) + } + OutgoingKind::Normal(server) => { + let mut edu_jsons = Vec::new(); + let mut pdu_jsons = Vec::new(); + + for event in &events { + match event { + SendingEventType::Pdu(pdu_id) => { + // TODO: check room version and remove event_id if needed + let raw = PduEvent::convert_to_outgoing_federation_event( + services().rooms + .timeline + .get_pdu_json_from_id(pdu_id) + .map_err(|e| (OutgoingKind::Normal(server.clone()), e))? + .ok_or_else(|| { + error!("event not found: {server} {pdu_id:?}"); + ( + OutgoingKind::Normal(server.clone()), + Error::bad_database( + "[Normal] Event in servernamevent_datas not found in db.", + ), + ) + })?, + ); + pdu_jsons.push(raw); + } + SendingEventType::Edu(edu) => { + if let Ok(raw) = serde_json::from_slice(edu) { + edu_jsons.push(raw); + } + } + } + } + + let permit = services().sending.maximum_requests.acquire().await; + + let response = server_server::send_request( + server, + send_transaction_message::v1::Request { + origin: services().globals.server_name().to_owned(), + pdus: pdu_jsons, + edus: edu_jsons, + origin_server_ts: MilliSecondsSinceUnixEpoch::now(), + transaction_id: (&*base64::encode_config( + calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, + }) + .collect::<Vec<_>>(), + ), + base64::URL_SAFE_NO_PAD, + )) + .into(), + }, + ) + .await + .map(|response| { + for pdu in response.pdus { + if pdu.1.is_err() { + warn!("Failed to send to {}: {:?}", server, pdu); + } + } + kind.clone() + }) + .map_err(|e| (kind, e)); + + drop(permit); + + response + } + } + } + + #[tracing::instrument(skip(self, destination, request))] + pub async fn send_federation_request<T: OutgoingRequest>( + &self, + destination: &ServerName, + request: T, + ) -> Result<T::IncomingResponse> + where + T: Debug, + { + let permit = self.maximum_requests.acquire().await; + let response = server_server::send_request(destination, request).await; + drop(permit); + + response + } + + #[tracing::instrument(skip(self, registration, request))] + pub async fn send_appservice_request<T: OutgoingRequest>( + &self, + registration: serde_yaml::Value, + request: T, + ) -> Result<T::IncomingResponse> + where + T: Debug, + { + let permit = self.maximum_requests.acquire().await; + let response = appservice_server::send_request(registration, request).await; + drop(permit); + + response + } +} diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs new file mode 100644 index 0000000..7485531 --- /dev/null +++ b/src/service/transaction_ids/data.rs @@ -0,0 +1,19 @@ +use crate::Result; +use ruma::{DeviceId, TransactionId, UserId}; + +pub trait Data: Send + Sync { + fn add_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + data: &[u8], + ) -> Result<()>; + + fn existing_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + ) -> Result<Option<Vec<u8>>>; +} diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs new file mode 100644 index 0000000..2fa3b02 --- /dev/null +++ b/src/service/transaction_ids/mod.rs @@ -0,0 +1,31 @@ +mod data; + +pub use data::Data; + +use crate::Result; +use ruma::{DeviceId, TransactionId, UserId}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + pub fn add_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + data: &[u8], + ) -> Result<()> { + self.db.add_txnid(user_id, device_id, txn_id, data) + } + + pub fn existing_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + ) -> Result<Option<Vec<u8>>> { + self.db.existing_txnid(user_id, device_id, txn_id) + } +} diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs new file mode 100644 index 0000000..c64deb9 --- /dev/null +++ b/src/service/uiaa/data.rs @@ -0,0 +1,34 @@ +use crate::Result; +use ruma::{api::client::uiaa::UiaaInfo, CanonicalJsonValue, DeviceId, UserId}; + +pub trait Data: Send + Sync { + fn set_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + request: &CanonicalJsonValue, + ) -> Result<()>; + + fn get_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Option<CanonicalJsonValue>; + + fn update_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + uiaainfo: Option<&UiaaInfo>, + ) -> Result<()>; + + fn get_uiaa_session( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Result<UiaaInfo>; +} diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs new file mode 100644 index 0000000..147ce4d --- /dev/null +++ b/src/service/uiaa/mod.rs @@ -0,0 +1,145 @@ +mod data; + +pub use data::Data; + +use ruma::{ + api::client::{ + error::ErrorKind, + uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier}, + }, + CanonicalJsonValue, DeviceId, UserId, +}; +use tracing::error; + +use crate::{api::client_server::SESSION_ID_LENGTH, services, utils, Error, Result}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Creates a new Uiaa session. Make sure the session token is unique. + pub fn create( + &self, + user_id: &UserId, + device_id: &DeviceId, + uiaainfo: &UiaaInfo, + json_body: &CanonicalJsonValue, + ) -> Result<()> { + self.db.set_uiaa_request( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?) + json_body, + )?; + self.db.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + Some(uiaainfo), + ) + } + + pub fn try_auth( + &self, + user_id: &UserId, + device_id: &DeviceId, + auth: &AuthData, + uiaainfo: &UiaaInfo, + ) -> Result<(bool, UiaaInfo)> { + let mut uiaainfo = auth + .session() + .map(|session| self.db.get_uiaa_session(user_id, device_id, session)) + .unwrap_or_else(|| Ok(uiaainfo.clone()))?; + + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } + + match auth { + // Find out what the user completed + AuthData::Password(Password { + identifier, + password, + .. + }) => { + let username = match identifier { + UserIdentifier::UserIdOrLocalpart(username) => username, + _ => { + return Err(Error::BadRequest( + ErrorKind::Unrecognized, + "Identifier type not recognized.", + )) + } + }; + + let user_id = UserId::parse_with_server_name( + username.clone(), + services().globals.server_name(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; + + // Check if password is correct + if let Some(hash) = services().users.password_hash(&user_id)? { + let hash_matches = + argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); + + if !hash_matches { + uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid username or password.".to_owned(), + }); + return Ok((false, uiaainfo)); + } + } + + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push(AuthType::Password); + } + AuthData::Dummy(_) => { + uiaainfo.completed.push(AuthType::Dummy); + } + k => error!("type not supported: {:?}", k), + } + + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } + } + // We didn't break, so this flow succeeded! + completed = true; + } + + if !completed { + self.db.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + Some(&uiaainfo), + )?; + return Ok((false, uiaainfo)); + } + + // UIAA was successful! Remove this session and return true + self.db.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + )?; + Ok((true, uiaainfo)) + } + + pub fn get_uiaa_request( + &self, + user_id: &UserId, + device_id: &DeviceId, + session: &str, + ) -> Option<CanonicalJsonValue> { + self.db.get_uiaa_request(user_id, device_id, session) + } +} diff --git a/src/service/users/data.rs b/src/service/users/data.rs new file mode 100644 index 0000000..8553210 --- /dev/null +++ b/src/service/users/data.rs @@ -0,0 +1,197 @@ +use crate::Result; +use ruma::{ + api::client::{device::Device, filter::FilterDefinition}, + encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, + events::AnyToDeviceEvent, + serde::Raw, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, + OwnedUserId, UInt, UserId, +}; +use std::collections::BTreeMap; + +pub trait Data: Send + Sync { + /// Check if a user has an account on this homeserver. + fn exists(&self, user_id: &UserId) -> Result<bool>; + + /// Check if account is deactivated + fn is_deactivated(&self, user_id: &UserId) -> Result<bool>; + + /// Returns the number of users registered on this server. + fn count(&self) -> Result<usize>; + + /// Find out which user an access token belongs to. + fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>>; + + /// Returns an iterator over all users on this homeserver. + fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>; + + /// Returns a list of local users as list of usernames. + /// + /// A user account is considered `local` if the length of it's password is greater then zero. + fn list_local_users(&self) -> Result<Vec<String>>; + + /// Returns the password hash for the given user. + fn password_hash(&self, user_id: &UserId) -> Result<Option<String>>; + + /// Hash and set the user's password to the Argon2 hash + fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; + + /// Returns the displayname of a user on this homeserver. + fn displayname(&self, user_id: &UserId) -> Result<Option<String>>; + + /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. + fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()>; + + /// Get the avatar_url of a user. + fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>>; + + /// Sets a new avatar_url or removes it if avatar_url is None. + fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()>; + + /// Get the blurhash of a user. + fn blurhash(&self, user_id: &UserId) -> Result<Option<String>>; + + /// Sets a new avatar_url or removes it if avatar_url is None. + fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()>; + + /// Adds a new device to a user. + fn create_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + token: &str, + initial_device_display_name: Option<String>, + ) -> Result<()>; + + /// Removes a device from a user. + fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + + /// Returns an iterator over all device ids of this user. + fn all_device_ids<'a>( + &'a self, + user_id: &UserId, + ) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a>; + + /// Replaces the access token of one device. + fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; + + fn add_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw<OneTimeKey>, + ) -> Result<()>; + + fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64>; + + fn take_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + key_algorithm: &DeviceKeyAlgorithm, + ) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>>; + + fn count_one_time_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>>; + + fn add_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + device_keys: &Raw<DeviceKeys>, + ) -> Result<()>; + + fn add_cross_signing_keys( + &self, + user_id: &UserId, + master_key: &Raw<CrossSigningKey>, + self_signing_key: &Option<Raw<CrossSigningKey>>, + user_signing_key: &Option<Raw<CrossSigningKey>>, + ) -> Result<()>; + + fn sign_key( + &self, + target_id: &UserId, + key_id: &str, + signature: (String, String), + sender_id: &UserId, + ) -> Result<()>; + + fn keys_changed<'a>( + &'a self, + user_or_room_id: &str, + from: u64, + to: Option<u64>, + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a>; + + fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; + + fn get_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<Option<Raw<DeviceKeys>>>; + + fn get_master_key( + &self, + user_id: &UserId, + allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result<Option<Raw<CrossSigningKey>>>; + + fn get_self_signing_key( + &self, + user_id: &UserId, + allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result<Option<Raw<CrossSigningKey>>>; + + fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>>; + + fn add_to_device_event( + &self, + sender: &UserId, + target_user_id: &UserId, + target_device_id: &DeviceId, + event_type: &str, + content: serde_json::Value, + ) -> Result<()>; + + fn get_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<Vec<Raw<AnyToDeviceEvent>>>; + + fn remove_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + until: u64, + ) -> Result<()>; + + fn update_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + device: &Device, + ) -> Result<()>; + + /// Get device metadata. + fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) + -> Result<Option<Device>>; + + fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>>; + + fn all_devices_metadata<'a>( + &'a self, + user_id: &UserId, + ) -> Box<dyn Iterator<Item = Result<Device>> + 'a>; + + /// Creates a new sync filter. Returns the filter id. + fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String>; + + fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>>; +} diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs new file mode 100644 index 0000000..6be5c89 --- /dev/null +++ b/src/service/users/mod.rs @@ -0,0 +1,367 @@ +mod data; +use std::{collections::BTreeMap, mem}; + +pub use data::Data; +use ruma::{ + api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, + encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, + events::AnyToDeviceEvent, + serde::Raw, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, + OwnedUserId, RoomAliasId, UInt, UserId, +}; + +use crate::{services, Error, Result}; + +pub struct Service { + pub db: &'static dyn Data, +} + +impl Service { + /// Check if a user has an account on this homeserver. + pub fn exists(&self, user_id: &UserId) -> Result<bool> { + self.db.exists(user_id) + } + + /// Check if account is deactivated + pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { + self.db.is_deactivated(user_id) + } + + /// Check if a user is an admin + pub fn is_admin(&self, user_id: &UserId) -> Result<bool> { + let admin_room_alias_id = + RoomAliasId::parse(format!("#admins:{}", services().globals.server_name())) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; + let admin_room_id = services() + .rooms + .alias + .resolve_local_alias(&admin_room_alias_id)? + .unwrap(); + + services() + .rooms + .state_cache + .is_joined(user_id, &admin_room_id) + } + + /// Create a new user account on this homeserver. + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.db.set_password(user_id, password)?; + Ok(()) + } + + /// Returns the number of users registered on this server. + pub fn count(&self) -> Result<usize> { + self.db.count() + } + + /// Find out which user an access token belongs to. + pub fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> { + self.db.find_from_token(token) + } + + /// Returns an iterator over all users on this homeserver. + pub fn iter(&self) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { + self.db.iter() + } + + /// Returns a list of local users as list of usernames. + /// + /// A user account is considered `local` if the length of it's password is greater then zero. + pub fn list_local_users(&self) -> Result<Vec<String>> { + self.db.list_local_users() + } + + /// Returns the password hash for the given user. + pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { + self.db.password_hash(user_id) + } + + /// Hash and set the user's password to the Argon2 hash + pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.db.set_password(user_id, password) + } + + /// Returns the displayname of a user on this homeserver. + pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { + self.db.displayname(user_id) + } + + /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. + pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { + self.db.set_displayname(user_id, displayname) + } + + /// Get the avatar_url of a user. + pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<OwnedMxcUri>> { + self.db.avatar_url(user_id) + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> { + self.db.set_avatar_url(user_id, avatar_url) + } + + /// Get the blurhash of a user. + pub fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> { + self.db.blurhash(user_id) + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> { + self.db.set_blurhash(user_id, blurhash) + } + + /// Adds a new device to a user. + pub fn create_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + token: &str, + initial_device_display_name: Option<String>, + ) -> Result<()> { + self.db + .create_device(user_id, device_id, token, initial_device_display_name) + } + + /// Removes a device from a user. + pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + self.db.remove_device(user_id, device_id) + } + + /// Returns an iterator over all device ids of this user. + pub fn all_device_ids<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator<Item = Result<OwnedDeviceId>> + 'a { + self.db.all_device_ids(user_id) + } + + /// Replaces the access token of one device. + pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + self.db.set_token(user_id, device_id, token) + } + + pub fn add_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw<OneTimeKey>, + ) -> Result<()> { + self.db + .add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) + } + + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { + self.db.last_one_time_keys_update(user_id) + } + + pub fn take_one_time_key( + &self, + user_id: &UserId, + device_id: &DeviceId, + key_algorithm: &DeviceKeyAlgorithm, + ) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> { + self.db.take_one_time_key(user_id, device_id, key_algorithm) + } + + pub fn count_one_time_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> { + self.db.count_one_time_keys(user_id, device_id) + } + + pub fn add_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + device_keys: &Raw<DeviceKeys>, + ) -> Result<()> { + self.db.add_device_keys(user_id, device_id, device_keys) + } + + pub fn add_cross_signing_keys( + &self, + user_id: &UserId, + master_key: &Raw<CrossSigningKey>, + self_signing_key: &Option<Raw<CrossSigningKey>>, + user_signing_key: &Option<Raw<CrossSigningKey>>, + ) -> Result<()> { + self.db + .add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key) + } + + pub fn sign_key( + &self, + target_id: &UserId, + key_id: &str, + signature: (String, String), + sender_id: &UserId, + ) -> Result<()> { + self.db.sign_key(target_id, key_id, signature, sender_id) + } + + pub fn keys_changed<'a>( + &'a self, + user_or_room_id: &str, + from: u64, + to: Option<u64>, + ) -> impl Iterator<Item = Result<OwnedUserId>> + 'a { + self.db.keys_changed(user_or_room_id, from, to) + } + + pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + self.db.mark_device_key_update(user_id) + } + + pub fn get_device_keys( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<Option<Raw<DeviceKeys>>> { + self.db.get_device_keys(user_id, device_id) + } + + pub fn get_master_key( + &self, + user_id: &UserId, + allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result<Option<Raw<CrossSigningKey>>> { + self.db.get_master_key(user_id, allowed_signatures) + } + + pub fn get_self_signing_key( + &self, + user_id: &UserId, + allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result<Option<Raw<CrossSigningKey>>> { + self.db.get_self_signing_key(user_id, allowed_signatures) + } + + pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { + self.db.get_user_signing_key(user_id) + } + + pub fn add_to_device_event( + &self, + sender: &UserId, + target_user_id: &UserId, + target_device_id: &DeviceId, + event_type: &str, + content: serde_json::Value, + ) -> Result<()> { + self.db.add_to_device_event( + sender, + target_user_id, + target_device_id, + event_type, + content, + ) + } + + pub fn get_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<Vec<Raw<AnyToDeviceEvent>>> { + self.db.get_to_device_events(user_id, device_id) + } + + pub fn remove_to_device_events( + &self, + user_id: &UserId, + device_id: &DeviceId, + until: u64, + ) -> Result<()> { + self.db.remove_to_device_events(user_id, device_id, until) + } + + pub fn update_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + device: &Device, + ) -> Result<()> { + self.db.update_device_metadata(user_id, device_id, device) + } + + /// Get device metadata. + pub fn get_device_metadata( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result<Option<Device>> { + self.db.get_device_metadata(user_id, device_id) + } + + pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { + self.db.get_devicelist_version(user_id) + } + + pub fn all_devices_metadata<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator<Item = Result<Device>> + 'a { + self.db.all_devices_metadata(user_id) + } + + /// Deactivate account + pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + // Remove all associated devices + for device_id in self.all_device_ids(user_id) { + self.remove_device(user_id, &device_id?)?; + } + + // Set the password to "" to indicate a deactivated account. Hashes will never result in an + // empty string, so the user will not be able to log in again. Systems like changing the + // password without logging in should check if the account is deactivated. + self.db.set_password(user_id, None)?; + + // TODO: Unhook 3PID + Ok(()) + } + + /// Creates a new sync filter. Returns the filter id. + pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result<String> { + self.db.create_filter(user_id, filter) + } + + pub fn get_filter( + &self, + user_id: &UserId, + filter_id: &str, + ) -> Result<Option<FilterDefinition>> { + self.db.get_filter(user_id, filter_id) + } +} + +/// Ensure that a user only sees signatures from themselves and the target user +pub fn clean_signatures<F: Fn(&UserId) -> bool>( + cross_signing_key: &mut serde_json::Value, + user_id: &UserId, + allowed_signatures: F, +) -> Result<(), Error> { + if let Some(signatures) = cross_signing_key + .get_mut("signatures") + .and_then(|v| v.as_object_mut()) + { + // Don't allocate for the full size of the current signatures, but require + // at most one resize if nothing is dropped + let new_capacity = signatures.len() / 2; + for (user, signature) in + mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) + { + let id = <&UserId>::try_from(user.as_str()) + .map_err(|_| Error::bad_database("Invalid user ID in database."))?; + if id == user_id || allowed_signatures(id) { + signatures.insert(user, signature); + } + } + } + + Ok(()) +} diff --git a/src/error.rs b/src/utils/error.rs index 206a055..4f044ca 100644 --- a/src/error.rs +++ b/src/utils/error.rs @@ -3,10 +3,10 @@ use std::convert::Infallible; use http::StatusCode; use ruma::{ api::client::{ - error::{Error as RumaError, ErrorKind}, + error::{Error as RumaError, ErrorBody, ErrorKind}, uiaa::{UiaaInfo, UiaaResponse}, }, - ServerName, + OwnedServerName, }; use thiserror::Error; use tracing::{error, warn}; @@ -55,7 +55,7 @@ pub enum Error { source: reqwest::Error, }, #[error("{0}")] - FederationError(Box<ServerName>, RumaError), + FederationError(OwnedServerName, RumaError), #[error("Could not do this io: {source}")] IoError { #[from] @@ -102,11 +102,14 @@ impl Error { if let Self::FederationError(origin, error) = self { let mut error = error.clone(); - error.message = format!("Answer from {}: {}", origin, error.message); + error.body = ErrorBody::Standard { + kind: Unknown, + message: format!("Answer from {origin}: {error}"), + }; return RumaResponse(UiaaResponse::MatrixError(error)); } - let message = format!("{}", self); + let message = format!("{self}"); use ErrorKind::*; let (kind, status_code) = match self { @@ -117,7 +120,7 @@ impl Error { StatusCode::FORBIDDEN } Unauthorized | UnknownToken { .. } | MissingToken => StatusCode::UNAUTHORIZED, - NotFound => StatusCode::NOT_FOUND, + NotFound | Unrecognized => StatusCode::NOT_FOUND, LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS, UserDeactivated => StatusCode::FORBIDDEN, TooLarge => StatusCode::PAYLOAD_TOO_LARGE, @@ -131,8 +134,7 @@ impl Error { warn!("{}: {}", status_code, message); RumaResponse(UiaaResponse::MatrixError(RumaError { - kind, - message, + body: ErrorBody::Standard { kind, message }, status_code, })) } diff --git a/src/utils.rs b/src/utils/mod.rs index 1ad0aa3..0b5b1ae 100644 --- a/src/utils.rs +++ b/src/utils/mod.rs @@ -1,7 +1,10 @@ +pub mod error; + use argon2::{Config, Variant}; use cmp::Ordering; use rand::prelude::*; -use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; +use ring::digest; +use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject}; use std::{ cmp, fmt, str::FromStr, @@ -57,7 +60,7 @@ pub fn random_string(length: usize) -> String { } /// Calculate a new hash for the given password -pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> { +pub fn calculate_password_hash(password: &str) -> Result<String, argon2::Error> { let hashing_config = Config { variant: Variant::Argon2id, ..Default::default() @@ -67,6 +70,14 @@ pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> { argon2::hash_encoded(password.as_bytes(), salt.as_bytes(), &hashing_config) } +#[tracing::instrument(skip(keys))] +pub fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> { + // We only hash the pdu's event ids, not the whole pdu + let bytes = keys.join(&0xff); + let hash = digest::digest(&digest::SHA256, &bytes); + hash.as_ref().to_owned() +} + pub fn common_elements( mut iterators: impl Iterator<Item = impl Iterator<Item = Vec<u8>>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering, diff --git a/tests/Complement.Dockerfile b/tests/Complement.Dockerfile index 22016e9..b9d0f8c 100644 --- a/tests/Complement.Dockerfile +++ b/tests/Complement.Dockerfile @@ -33,7 +33,7 @@ RUN sed -i "s/port = 6167/port = 8008/g" conduit.toml RUN echo "allow_federation = true" >> conduit.toml RUN echo "allow_encryption = true" >> conduit.toml RUN echo "allow_registration = true" >> conduit.toml -RUN echo "log = \"info,_=off,sled=off\"" >> conduit.toml +RUN echo "log = \"warn,_=off,sled=off\"" >> conduit.toml RUN sed -i "s/address = \"127.0.0.1\"/address = \"0.0.0.0\"/g" conduit.toml # Enabled Caddy auto cert generation for complement provided CA. |