mirror of
https://github.com/oddlama/nix-config.git
synced 2025-10-10 23:00:39 +02:00
feat: add my own realtime-stt-server script; add fixed jaxlib
This commit is contained in:
parent
5e9143778f
commit
e1e8997525
5 changed files with 823 additions and 0 deletions
|
@ -33,6 +33,15 @@
|
|||
wrapProgram $out/bin/nvim --add-flags "--clean"
|
||||
'';
|
||||
});
|
||||
pythonPackagesExtensions =
|
||||
prev.pythonPackagesExtensions
|
||||
++ [
|
||||
(pythonFinal: _pythonPrev: {
|
||||
jaxlib = pythonFinal.callPackage ./jaxlib.nix {};
|
||||
realtime-stt = pythonFinal.callPackage ./realtime-stt.nix {};
|
||||
})
|
||||
];
|
||||
realtime-stt-server = prev.callPackage ./realtime-stt-server.nix {};
|
||||
|
||||
formats =
|
||||
prev.formats
|
||||
|
|
451
pkgs/jaxlib.nix
Normal file
451
pkgs/jaxlib.nix
Normal file
|
@ -0,0 +1,451 @@
|
|||
{
|
||||
lib,
|
||||
pkgs,
|
||||
# Build-time dependencies:
|
||||
autoAddDriverRunpath,
|
||||
bazel_6,
|
||||
binutils,
|
||||
buildBazelPackage,
|
||||
buildPythonPackage,
|
||||
curl,
|
||||
cython,
|
||||
fetchFromGitHub,
|
||||
git,
|
||||
jsoncpp,
|
||||
nsync,
|
||||
openssl,
|
||||
pybind11,
|
||||
setuptools,
|
||||
symlinkJoin,
|
||||
wheel,
|
||||
build,
|
||||
which,
|
||||
# Python dependencies:
|
||||
absl-py,
|
||||
flatbuffers,
|
||||
ml-dtypes,
|
||||
numpy,
|
||||
scipy,
|
||||
six,
|
||||
# Runtime dependencies:
|
||||
double-conversion,
|
||||
giflib,
|
||||
libjpeg_turbo,
|
||||
python,
|
||||
snappy,
|
||||
zlib,
|
||||
config,
|
||||
# CUDA flags:
|
||||
cudaSupport ? config.cudaSupport,
|
||||
cudaPackages,
|
||||
# MKL:
|
||||
mklSupport ? true,
|
||||
} @ inputs: let
|
||||
inherit
|
||||
(cudaPackages)
|
||||
cudaFlags
|
||||
cudaVersion
|
||||
cudnn
|
||||
nccl
|
||||
;
|
||||
|
||||
pname = "jaxlib";
|
||||
version = "0.4.28";
|
||||
|
||||
# It's necessary to consistently use backendStdenv when building with CUDA
|
||||
# support, otherwise we get libstdc++ errors downstream
|
||||
stdenv = throw "Use effectiveStdenv instead";
|
||||
effectiveStdenv =
|
||||
if cudaSupport
|
||||
then cudaPackages.backendStdenv
|
||||
else inputs.stdenv;
|
||||
|
||||
meta = with lib; {
|
||||
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research";
|
||||
homepage = "https://github.com/google/jax";
|
||||
license = licenses.asl20;
|
||||
maintainers = with maintainers; [ndl];
|
||||
platforms = platforms.unix;
|
||||
# aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
|
||||
# however even with that fix applied, it doesn't work for everyone:
|
||||
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
|
||||
# NOTE: We always build with NCCL; if it is unsupported, then our build is broken.
|
||||
broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
|
||||
};
|
||||
|
||||
# These are necessary at build time and run time.
|
||||
cuda_libs_joined = symlinkJoin {
|
||||
name = "cuda-joined";
|
||||
paths = with cudaPackages; [
|
||||
cuda_cudart.lib # libcudart.so
|
||||
cuda_cudart.static # libcudart_static.a
|
||||
cuda_cupti.lib # libcupti.so
|
||||
libcublas.lib # libcublas.so
|
||||
libcufft.lib # libcufft.so
|
||||
libcurand.lib # libcurand.so
|
||||
libcusolver.lib # libcusolver.so
|
||||
libcusparse.lib # libcusparse.so
|
||||
];
|
||||
};
|
||||
# These are only necessary at build time.
|
||||
cuda_build_deps_joined = symlinkJoin {
|
||||
name = "cuda-build-deps-joined";
|
||||
paths = with cudaPackages; [
|
||||
cuda_libs_joined
|
||||
|
||||
# Binaries
|
||||
cudaPackages.cuda_nvcc.bin # nvcc
|
||||
|
||||
# Headers
|
||||
cuda_cccl.dev # block_load.cuh
|
||||
cuda_cudart.dev # cuda.h
|
||||
cuda_cupti.dev # cupti.h
|
||||
cuda_nvcc.dev # See https://github.com/google/jax/issues/19811
|
||||
cuda_nvml_dev # nvml.h
|
||||
cuda_nvtx.dev # nvToolsExt.h
|
||||
libcublas.dev # cublas_api.h
|
||||
libcufft.dev # cufft.h
|
||||
libcurand.dev # curand.h
|
||||
libcusolver.dev # cusolver_common.h
|
||||
libcusparse.dev # cusparse.h
|
||||
];
|
||||
};
|
||||
|
||||
backend_cc_joined = symlinkJoin {
|
||||
name = "cuda-cc-joined";
|
||||
paths = [
|
||||
effectiveStdenv.cc
|
||||
binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
|
||||
];
|
||||
};
|
||||
|
||||
# Copy-paste from TF derivation.
|
||||
# Most of these are not really used in jaxlib compilation but it's simpler to keep it
|
||||
# 'as is' so that it's more compatible with TF derivation.
|
||||
tf_system_libs = [
|
||||
"absl_py"
|
||||
"astor_archive"
|
||||
"astunparse_archive"
|
||||
# Not packaged in nixpkgs
|
||||
# "com_github_googleapis_googleapis"
|
||||
# "com_github_googlecloudplatform_google_cloud_cpp"
|
||||
# Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
|
||||
# "com_github_grpc_grpc"
|
||||
# ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
|
||||
# target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
|
||||
# "com_google_protobuf"
|
||||
# Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
|
||||
# "com_googlesource_code_re2"
|
||||
"curl"
|
||||
"cython"
|
||||
"dill_archive"
|
||||
"double_conversion"
|
||||
"flatbuffers"
|
||||
"functools32_archive"
|
||||
"gast_archive"
|
||||
"gif"
|
||||
"hwloc"
|
||||
"icu"
|
||||
"jsoncpp_git"
|
||||
"libjpeg_turbo"
|
||||
"lmdb"
|
||||
"nasm"
|
||||
"opt_einsum_archive"
|
||||
"org_sqlite"
|
||||
"pasta"
|
||||
"png"
|
||||
# ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
|
||||
# target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
|
||||
# "pybind11"
|
||||
"six_archive"
|
||||
"snappy"
|
||||
"tblib_archive"
|
||||
"termcolor_archive"
|
||||
"typing_extensions_archive"
|
||||
"wrapt"
|
||||
"zlib"
|
||||
];
|
||||
|
||||
arch =
|
||||
# KeyError: ('Linux', 'arm64')
|
||||
if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64"
|
||||
then "aarch64"
|
||||
else effectiveStdenv.hostPlatform.linuxArch;
|
||||
|
||||
xla = effectiveStdenv.mkDerivation {
|
||||
pname = "xla-src";
|
||||
version = "unstable";
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "openxla";
|
||||
repo = "xla";
|
||||
# Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
|
||||
rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
|
||||
hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
|
||||
};
|
||||
|
||||
dontBuild = true;
|
||||
|
||||
# This is necessary for patchShebangs to know the right path to use.
|
||||
nativeBuildInputs = [python];
|
||||
|
||||
# Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
|
||||
postPatch = ''
|
||||
patchShebangs .
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
cp -r . $out
|
||||
'';
|
||||
};
|
||||
|
||||
bazel-build = buildBazelPackage rec {
|
||||
name = "bazel-build-${pname}-${version}";
|
||||
|
||||
# See https://github.com/google/jax/blob/main/.bazelversion for the latest.
|
||||
bazel = bazel_6;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "google";
|
||||
repo = "jax";
|
||||
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
|
||||
rev = "refs/tags/${pname}-v${version}";
|
||||
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
cython
|
||||
pkgs.flatbuffers
|
||||
git
|
||||
setuptools
|
||||
wheel
|
||||
build
|
||||
which
|
||||
];
|
||||
|
||||
buildInputs =
|
||||
[
|
||||
curl
|
||||
double-conversion
|
||||
giflib
|
||||
jsoncpp
|
||||
libjpeg_turbo
|
||||
numpy
|
||||
openssl
|
||||
pkgs.flatbuffers
|
||||
pkgs.protobuf
|
||||
pybind11
|
||||
scipy
|
||||
six
|
||||
snappy
|
||||
zlib
|
||||
]
|
||||
++ lib.optionals (!effectiveStdenv.isDarwin) [nsync];
|
||||
|
||||
# We don't want to be quite so picky regarding bazel version
|
||||
postPatch = ''
|
||||
rm -f .bazelversion
|
||||
'';
|
||||
|
||||
bazelRunTarget = "//jaxlib/tools:build_wheel";
|
||||
runTargetFlags = [
|
||||
"--output_path=$out"
|
||||
"--cpu=${arch}"
|
||||
# This has no impact whatsoever...
|
||||
"--jaxlib_git_hash='12345678'"
|
||||
];
|
||||
|
||||
removeRulesCC = false;
|
||||
|
||||
GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
|
||||
GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";
|
||||
|
||||
# The version is automatically set to ".dev" if this variable is not set.
|
||||
# https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
|
||||
JAXLIB_RELEASE = "1";
|
||||
|
||||
preConfigure =
|
||||
# Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
|
||||
''
|
||||
mkdir dummy-ldconfig
|
||||
echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
|
||||
chmod +x dummy-ldconfig/ldconfig
|
||||
export PATH="$PWD/dummy-ldconfig:$PATH"
|
||||
''
|
||||
+
|
||||
# Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
|
||||
# for more info. We assume
|
||||
# * `cpu = None`
|
||||
# * `enable_nccl = True`
|
||||
# * `target_cpu_features = "release"`
|
||||
# * `rocm_amdgpu_targets = None`
|
||||
# * `enable_rocm = False`
|
||||
# * `build_gpu_plugin = False`
|
||||
# * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
|
||||
#
|
||||
# Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
|
||||
# instead of duplicating the logic here. Perhaps we can leverage the
|
||||
# `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
|
||||
''
|
||||
cat <<CFG > ./.jax_configure.bazelrc
|
||||
build --strategy=Genrule=standalone
|
||||
build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
|
||||
build --action_env=PYENV_ROOT
|
||||
build --python_path="${python}/bin/python"
|
||||
build --distinct_host_configuration=false
|
||||
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
|
||||
''
|
||||
+ lib.optionalString cudaSupport ''
|
||||
build --config=cuda
|
||||
build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
|
||||
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
|
||||
build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnn},${nccl}"
|
||||
build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
|
||||
build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
|
||||
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
|
||||
''
|
||||
+
|
||||
# Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
|
||||
# rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
|
||||
# good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
|
||||
# for upstream's version.
|
||||
lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
|
||||
''
|
||||
build --config=avx_posix
|
||||
''
|
||||
+ lib.optionalString mklSupport ''
|
||||
build --config=mkl_open_source_only
|
||||
''
|
||||
+ ''
|
||||
CFG
|
||||
'';
|
||||
|
||||
# Make sure Bazel knows about our configuration flags during fetching so that the
|
||||
# relevant dependencies can be downloaded.
|
||||
bazelFlags =
|
||||
[
|
||||
"-c opt"
|
||||
# See https://bazel.build/external/advanced#overriding-repositories for
|
||||
# information on --override_repository flag.
|
||||
"--override_repository=xla=${xla}"
|
||||
]
|
||||
++ lib.optionals effectiveStdenv.cc.isClang [
|
||||
# bazel depends on the compiler frontend automatically selecting these flags based on file
|
||||
# extension but our clang doesn't.
|
||||
# https://github.com/NixOS/nixpkgs/issues/150655
|
||||
"--cxxopt=-x"
|
||||
"--cxxopt=c++"
|
||||
"--host_cxxopt=-x"
|
||||
"--host_cxxopt=c++"
|
||||
];
|
||||
|
||||
# We intentionally overfetch so we can share the fetch derivation across all the different configurations
|
||||
fetchAttrs = {
|
||||
TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
|
||||
# we have to force @mkl_dnn_v1 since it's not needed on darwin
|
||||
bazelTargets = [
|
||||
bazelRunTarget
|
||||
"@mkl_dnn_v1//:mkl_dnn"
|
||||
];
|
||||
bazelFlags =
|
||||
bazelFlags
|
||||
++ [
|
||||
"--config=avx_posix"
|
||||
"--config=mkl_open_source_only"
|
||||
]
|
||||
++ lib.optionals cudaSupport [
|
||||
# ideally we'd add this unconditionally too, but it doesn't work on darwin
|
||||
# we make this conditional on `cudaSupport` instead of the system, so that the hash for both
|
||||
# the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
|
||||
# have access to darwin machines
|
||||
"--config=cuda"
|
||||
];
|
||||
|
||||
sha256 =
|
||||
(
|
||||
if cudaSupport
|
||||
then {x86_64-linux = "sha256-vUoAPkYKEnHkV4fw6BI0mCeuP2e8BMCJnVuZMm9LwSA=";}
|
||||
else {
|
||||
x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
|
||||
aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
|
||||
}
|
||||
)
|
||||
.${effectiveStdenv.system}
|
||||
or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
|
||||
};
|
||||
|
||||
buildAttrs = {
|
||||
outputs = ["out"];
|
||||
|
||||
TF_SYSTEM_LIBS = lib.concatStringsSep "," (
|
||||
tf_system_libs
|
||||
++ lib.optionals (!effectiveStdenv.isDarwin) [
|
||||
"nsync" # fails to build on darwin
|
||||
]
|
||||
);
|
||||
};
|
||||
|
||||
inherit meta;
|
||||
};
|
||||
platformTag =
|
||||
if effectiveStdenv.hostPlatform.isLinux
|
||||
then "manylinux2014_${arch}"
|
||||
else if effectiveStdenv.system == "x86_64-darwin"
|
||||
then "macosx_10_9_${arch}"
|
||||
else if effectiveStdenv.system == "aarch64-darwin"
|
||||
then "macosx_11_0_${arch}"
|
||||
else throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
|
||||
in
|
||||
buildPythonPackage {
|
||||
inherit meta pname version;
|
||||
format = "wheel";
|
||||
|
||||
src = let
|
||||
cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
|
||||
in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
|
||||
|
||||
# Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
|
||||
# for more info.
|
||||
postInstall = lib.optionalString cudaSupport ''
|
||||
mkdir -p $out/bin
|
||||
ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
|
||||
|
||||
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
|
||||
patchelf --add-rpath "${
|
||||
lib.makeLibraryPath [
|
||||
cuda_libs_joined
|
||||
cudnn
|
||||
nccl
|
||||
]
|
||||
}" "$lib"
|
||||
done
|
||||
'';
|
||||
|
||||
nativeBuildInputs = lib.optionals cudaSupport [autoAddDriverRunpath];
|
||||
|
||||
dependencies = [
|
||||
absl-py
|
||||
curl
|
||||
double-conversion
|
||||
flatbuffers
|
||||
giflib
|
||||
jsoncpp
|
||||
libjpeg_turbo
|
||||
ml-dtypes
|
||||
numpy
|
||||
scipy
|
||||
six
|
||||
snappy
|
||||
];
|
||||
|
||||
pythonImportsCheck = [
|
||||
"jaxlib"
|
||||
# `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
|
||||
"jaxlib.cpu_feature_guard"
|
||||
"jaxlib.xla_client"
|
||||
];
|
||||
|
||||
# Without it there are complaints about libcudart.so.11.0 not being found
|
||||
# because RPATH path entries added above are stripped.
|
||||
dontPatchELF = cudaSupport;
|
||||
}
|
26
pkgs/realtime-stt-server.nix
Normal file
26
pkgs/realtime-stt-server.nix
Normal file
|
@ -0,0 +1,26 @@
|
|||
{
|
||||
lib,
|
||||
python3,
|
||||
stdenv,
|
||||
}:
|
||||
stdenv.mkDerivation {
|
||||
pname = "realtime-stt-server";
|
||||
version = "1.0.0";
|
||||
|
||||
dontUnpack = true;
|
||||
propagatedBuildInputs = [
|
||||
(python3.withPackages (pythonPackages: with pythonPackages; [realtime-stt]))
|
||||
];
|
||||
|
||||
installPhase = ''
|
||||
install -Dm755 ${./realtime-stt-server.py} $out/bin/realtime-stt-server
|
||||
'';
|
||||
|
||||
meta = {
|
||||
description = "";
|
||||
homepage = "";
|
||||
license = lib.licenses.mit;
|
||||
maintainers = with lib.maintainers; [oddlama];
|
||||
mainProgram = "realtime-stt-server";
|
||||
};
|
||||
}
|
258
pkgs/realtime-stt-server.py
Normal file
258
pkgs/realtime-stt-server.py
Normal file
|
@ -0,0 +1,258 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import queue
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
import sys
|
||||
import threading
|
||||
|
||||
def send_message(sock, message):
|
||||
message_str = json.dumps(message)
|
||||
message_bytes = message_str.encode("utf-8")
|
||||
message_length = len(message_bytes)
|
||||
sock.sendall(struct.pack("!I", message_length))
|
||||
sock.sendall(message_bytes)
|
||||
|
||||
def recv_message(sock):
|
||||
length_bytes = sock.recv(4)
|
||||
if not length_bytes or len(length_bytes) == 0:
|
||||
return None
|
||||
message_length = struct.unpack("!I", length_bytes)[0]
|
||||
if message_length & 0x80000000 != 0:
|
||||
# Raw audio data
|
||||
message_length &= ~0x80000000
|
||||
message_bytes = sock.recv(message_length)
|
||||
return message_bytes
|
||||
|
||||
message_bytes = sock.recv(message_length)
|
||||
message_str = message_bytes.decode("utf-8")
|
||||
return json.loads(message_str)
|
||||
|
||||
class Client:
|
||||
def __init__(self, tag, conn):
|
||||
self.tag = tag
|
||||
self.conn = conn
|
||||
self.thread = threading.current_thread()
|
||||
self.mode = None
|
||||
self.is_true_client = False
|
||||
self.waiting = False
|
||||
self.queue = queue.Queue()
|
||||
|
||||
clients = {}
|
||||
active_client = None
|
||||
model_lock = threading.Lock()
|
||||
|
||||
def publish(obj, client=None):
|
||||
msg = json.dumps(obj)
|
||||
if client is None:
|
||||
for c in clients.values():
|
||||
if c.mode == "status":
|
||||
c.queue.put(msg)
|
||||
else:
|
||||
client.queue.put(msg)
|
||||
|
||||
def refresh_status(client=None):
|
||||
publish(dict(refresh_status=True), client=client)
|
||||
|
||||
def handle_client(conn, addr):
|
||||
global recorder
|
||||
global active_client
|
||||
tag = f"{addr[0]}:{addr[1]}"
|
||||
client = Client(tag, conn)
|
||||
clients[addr] = client
|
||||
|
||||
try:
|
||||
logger.info(f'{tag} Connected to client')
|
||||
init = recv_message(conn)
|
||||
logger.info(f'{tag} Client requested mode {init["mode"]}')
|
||||
client.mode = init["mode"]
|
||||
client.is_true_client = init["mode"] == "stream"
|
||||
|
||||
if init["mode"] == "status":
|
||||
refresh_status(client) # refresh once after startup
|
||||
|
||||
while True:
|
||||
message = json.loads(client.queue.get())
|
||||
|
||||
if "refresh_status" in message and message["refresh_status"] == True:
|
||||
n_clients = len(list(filter(lambda x: x.is_true_client, clients.values())))
|
||||
n_waiting = len(list(filter(lambda x: x.is_true_client and x.waiting, clients.values())))
|
||||
status = {
|
||||
"clients": n_clients,
|
||||
"waiting": n_waiting,
|
||||
}
|
||||
send_message(conn, status)
|
||||
client.queue.task_done()
|
||||
else:
|
||||
logger.info(f'{tag} Acquiring lock')
|
||||
client.waiting = True
|
||||
refresh_status()
|
||||
send_message(conn, dict(status="waiting for lock"))
|
||||
|
||||
with model_lock:
|
||||
active_client = client
|
||||
client.waiting = False
|
||||
refresh_status()
|
||||
send_message(conn, dict(status="lock acquired"))
|
||||
recorder.start()
|
||||
|
||||
def send_queue():
|
||||
try:
|
||||
while True:
|
||||
message = client.queue.get()
|
||||
if message is None:
|
||||
return
|
||||
send_message(conn, message)
|
||||
client.queue.task_done()
|
||||
except (OSError, ConnectionError):
|
||||
logger.info(f"{tag} error in send queue: connection closed?")
|
||||
|
||||
sender_thread = threading.Thread(target=send_queue)
|
||||
sender_thread.daemon = True
|
||||
sender_thread.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
msg = recv_message(conn)
|
||||
if msg is None:
|
||||
break
|
||||
|
||||
if isinstance(msg, bytes):
|
||||
recorder.feed_audio(msg)
|
||||
continue
|
||||
|
||||
if "action" in msg and msg["action"] == "flush":
|
||||
logger.info(f"{tag} flushing on client request")
|
||||
# input some silence
|
||||
for i in range(10):
|
||||
recorder.feed_audio(bytes(1000))
|
||||
recorder.stop()
|
||||
logger.info(f"{tag} flushed")
|
||||
continue
|
||||
else:
|
||||
logger.info(f"{tag} error in recv: invalid message: {msg}")
|
||||
continue
|
||||
except (OSError, ConnectionError):
|
||||
logger.info(f"{tag} error in recv: connection closed?")
|
||||
finally:
|
||||
client.queue.put(None)
|
||||
active_client = None
|
||||
recorder.stop()
|
||||
sender_thread.join()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f'{tag} Error handling client: {e}')
|
||||
finally:
|
||||
refresh_status()
|
||||
del clients[addr]
|
||||
conn.close()
|
||||
logger.info(f'{tag} Connection closed')
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(format="%(levelname)s %(message)s")
|
||||
logger = logging.getLogger("realtime-stt-server")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
#logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default='localhost')
|
||||
parser.add_argument("--port", type=int, default=43007)
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info("Importing runtime")
|
||||
from RealtimeSTT import AudioToTextRecorder
|
||||
|
||||
def text_detected(ts):
|
||||
text, segments = ts
|
||||
global active_client
|
||||
if active_client is not None:
|
||||
segments = [x._asdict() for x in segments]
|
||||
active_client.queue.put(dict(kind="realtime", text=text, segments=segments))
|
||||
|
||||
recorder_ready = threading.Event()
|
||||
recorder_config = {
|
||||
'init_logging': False,
|
||||
|
||||
'use_microphone': False,
|
||||
'spinner': False,
|
||||
'model': 'large-v3',
|
||||
'return_segments': True,
|
||||
#'language': 'en',
|
||||
|
||||
'silero_sensitivity': 0.4,
|
||||
'webrtc_sensitivity': 2,
|
||||
'post_speech_silence_duration': 0.7,
|
||||
'min_length_of_recording': 0.0,
|
||||
'min_gap_between_recordings': 0,
|
||||
|
||||
'enable_realtime_transcription': True,
|
||||
'realtime_processing_pause': 0,
|
||||
'realtime_model_type': 'base',
|
||||
|
||||
'on_realtime_transcription_stabilized': text_detected,
|
||||
}
|
||||
|
||||
def recorder_thread():
|
||||
global recorder
|
||||
global active_client
|
||||
logger.info("Initializing RealtimeSTT...")
|
||||
recorder = AudioToTextRecorder(**recorder_config)
|
||||
logger.info("AudioToTextRecorder ready")
|
||||
recorder_ready.set()
|
||||
try:
|
||||
while not recorder.is_shut_down:
|
||||
text, segments = recorder.text()
|
||||
if text == "":
|
||||
continue
|
||||
if active_client is not None:
|
||||
segments = [x._asdict() for x in segments]
|
||||
active_client.queue.put(dict(kind="result", text=text, segments=segments))
|
||||
except (OSError, EOFError) as e:
|
||||
logger.info(f"recorder thread failed: {e}")
|
||||
return
|
||||
|
||||
recorder_thread = threading.Thread(target=recorder_thread)
|
||||
recorder_thread.start()
|
||||
recorder_ready.wait()
|
||||
|
||||
logger.info(f'Starting server on {args.host}:{args.port}')
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
s.bind((args.host, args.port))
|
||||
s.listen()
|
||||
logger.info(f'Server ready to accept connections')
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Accept incoming connection
|
||||
conn, addr = s.accept()
|
||||
conn.setblocking(True)
|
||||
|
||||
# Create a new thread to handle the client
|
||||
client_thread = threading.Thread(target=handle_client, args=(conn, addr))
|
||||
client_thread.daemon = True # die with main thread
|
||||
client_thread.start()
|
||||
|
||||
# Note: The main thread continues to accept new connections
|
||||
except KeyboardInterrupt:
|
||||
logger.info(f'Received shutdown request')
|
||||
|
||||
for c in clients.values():
|
||||
try:
|
||||
c.conn.close()
|
||||
except (OSError, ConnectionError):
|
||||
pass
|
||||
try:
|
||||
s.close()
|
||||
except (OSError, ConnectionError):
|
||||
pass
|
||||
|
||||
recorder.shutdown()
|
||||
|
||||
logger.info('Server terminated')
|
79
pkgs/realtime-stt.nix
Normal file
79
pkgs/realtime-stt.nix
Normal file
|
@ -0,0 +1,79 @@
|
|||
{
|
||||
lib,
|
||||
buildPythonPackage,
|
||||
fetchFromGitHub,
|
||||
setuptools,
|
||||
wheel,
|
||||
faster-whisper,
|
||||
pyaudio,
|
||||
scipy,
|
||||
torch,
|
||||
torchaudio,
|
||||
webrtcvad,
|
||||
websockets,
|
||||
}:
|
||||
buildPythonPackage rec {
|
||||
pname = "realtime-stt";
|
||||
version = "0.1.16";
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "oddlama";
|
||||
repo = "RealtimeSTT";
|
||||
rev = "master";
|
||||
hash = "sha256-64RE/aT5PxuFFUTvjNefqTlAKWG1fftKV0wcY/hFlcg=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
setuptools
|
||||
wheel
|
||||
];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
faster-whisper
|
||||
pyaudio
|
||||
scipy
|
||||
torch
|
||||
torchaudio
|
||||
webrtcvad
|
||||
websockets
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
# Remove unneded modules
|
||||
substituteInPlace RealtimeSTT/audio_recorder.py \
|
||||
--replace-fail 'import pvporcupine' "" \
|
||||
--replace-fail 'import halo' ""
|
||||
'';
|
||||
|
||||
preBuild = ''
|
||||
cat > setup.py << EOF
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name='realtime-stt',
|
||||
packages=['RealtimeSTT'],
|
||||
version='${version}',
|
||||
install_requires=[
|
||||
"PyAudio",
|
||||
"faster-whisper",
|
||||
#"pvporcupine",
|
||||
"webrtcvad",
|
||||
"#halo",
|
||||
"torch",
|
||||
"torchaudio",
|
||||
"scipy",
|
||||
"websockets",
|
||||
],
|
||||
)
|
||||
EOF
|
||||
'';
|
||||
|
||||
pythonImportsCheck = ["RealtimeSTT"];
|
||||
|
||||
meta = {
|
||||
description = "A robust, efficient, low-latency speech-to-text library with advanced voice activity detection, wake word activation and instant transcription";
|
||||
homepage = "https://github.com/KoljaB/RealtimeSTT";
|
||||
license = lib.licenses.mit;
|
||||
maintainers = with lib.maintainers; [oddlama];
|
||||
};
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue