1
1
Fork 1
mirror of https://github.com/oddlama/nix-config.git synced 2025-10-10 23:00:39 +02:00

feat: add my own realtime-stt-server script; add fixed jaxlib

This commit is contained in:
oddlama 2024-06-20 17:51:32 +02:00
parent 5e9143778f
commit e1e8997525
No known key found for this signature in database
GPG key ID: 14EFE510775FE39A
5 changed files with 823 additions and 0 deletions

View file

@ -33,6 +33,15 @@
wrapProgram $out/bin/nvim --add-flags "--clean"
'';
});
pythonPackagesExtensions =
prev.pythonPackagesExtensions
++ [
(pythonFinal: _pythonPrev: {
jaxlib = pythonFinal.callPackage ./jaxlib.nix {};
realtime-stt = pythonFinal.callPackage ./realtime-stt.nix {};
})
];
realtime-stt-server = prev.callPackage ./realtime-stt-server.nix {};
formats =
prev.formats

451
pkgs/jaxlib.nix Normal file
View file

@ -0,0 +1,451 @@
{
lib,
pkgs,
# Build-time dependencies:
autoAddDriverRunpath,
bazel_6,
binutils,
buildBazelPackage,
buildPythonPackage,
curl,
cython,
fetchFromGitHub,
git,
jsoncpp,
nsync,
openssl,
pybind11,
setuptools,
symlinkJoin,
wheel,
build,
which,
# Python dependencies:
absl-py,
flatbuffers,
ml-dtypes,
numpy,
scipy,
six,
# Runtime dependencies:
double-conversion,
giflib,
libjpeg_turbo,
python,
snappy,
zlib,
config,
# CUDA flags:
cudaSupport ? config.cudaSupport,
cudaPackages,
# MKL:
mklSupport ? true,
} @ inputs: let
inherit
(cudaPackages)
cudaFlags
cudaVersion
cudnn
nccl
;
pname = "jaxlib";
version = "0.4.28";
# It's necessary to consistently use backendStdenv when building with CUDA
# support, otherwise we get libstdc++ errors downstream
stdenv = throw "Use effectiveStdenv instead";
effectiveStdenv =
if cudaSupport
then cudaPackages.backendStdenv
else inputs.stdenv;
meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research";
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ndl];
platforms = platforms.unix;
# aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
# however even with that fix applied, it doesn't work for everyone:
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
# NOTE: We always build with NCCL; if it is unsupported, then our build is broken.
broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
};
# These are necessary at build time and run time.
cuda_libs_joined = symlinkJoin {
name = "cuda-joined";
paths = with cudaPackages; [
cuda_cudart.lib # libcudart.so
cuda_cudart.static # libcudart_static.a
cuda_cupti.lib # libcupti.so
libcublas.lib # libcublas.so
libcufft.lib # libcufft.so
libcurand.lib # libcurand.so
libcusolver.lib # libcusolver.so
libcusparse.lib # libcusparse.so
];
};
# These are only necessary at build time.
cuda_build_deps_joined = symlinkJoin {
name = "cuda-build-deps-joined";
paths = with cudaPackages; [
cuda_libs_joined
# Binaries
cudaPackages.cuda_nvcc.bin # nvcc
# Headers
cuda_cccl.dev # block_load.cuh
cuda_cudart.dev # cuda.h
cuda_cupti.dev # cupti.h
cuda_nvcc.dev # See https://github.com/google/jax/issues/19811
cuda_nvml_dev # nvml.h
cuda_nvtx.dev # nvToolsExt.h
libcublas.dev # cublas_api.h
libcufft.dev # cufft.h
libcurand.dev # curand.h
libcusolver.dev # cusolver_common.h
libcusparse.dev # cusparse.h
];
};
backend_cc_joined = symlinkJoin {
name = "cuda-cc-joined";
paths = [
effectiveStdenv.cc
binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
];
};
# Copy-paste from TF derivation.
# Most of these are not really used in jaxlib compilation but it's simpler to keep it
# 'as is' so that it's more compatible with TF derivation.
tf_system_libs = [
"absl_py"
"astor_archive"
"astunparse_archive"
# Not packaged in nixpkgs
# "com_github_googleapis_googleapis"
# "com_github_googlecloudplatform_google_cloud_cpp"
# Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
# "com_github_grpc_grpc"
# ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
# target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
# "com_google_protobuf"
# Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
# "com_googlesource_code_re2"
"curl"
"cython"
"dill_archive"
"double_conversion"
"flatbuffers"
"functools32_archive"
"gast_archive"
"gif"
"hwloc"
"icu"
"jsoncpp_git"
"libjpeg_turbo"
"lmdb"
"nasm"
"opt_einsum_archive"
"org_sqlite"
"pasta"
"png"
# ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
# target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
# "pybind11"
"six_archive"
"snappy"
"tblib_archive"
"termcolor_archive"
"typing_extensions_archive"
"wrapt"
"zlib"
];
arch =
# KeyError: ('Linux', 'arm64')
if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64"
then "aarch64"
else effectiveStdenv.hostPlatform.linuxArch;
xla = effectiveStdenv.mkDerivation {
pname = "xla-src";
version = "unstable";
src = fetchFromGitHub {
owner = "openxla";
repo = "xla";
# Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
};
dontBuild = true;
# This is necessary for patchShebangs to know the right path to use.
nativeBuildInputs = [python];
# Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
postPatch = ''
patchShebangs .
'';
installPhase = ''
cp -r . $out
'';
};
bazel-build = buildBazelPackage rec {
name = "bazel-build-${pname}-${version}";
# See https://github.com/google/jax/blob/main/.bazelversion for the latest.
bazel = bazel_6;
src = fetchFromGitHub {
owner = "google";
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
};
nativeBuildInputs = [
cython
pkgs.flatbuffers
git
setuptools
wheel
build
which
];
buildInputs =
[
curl
double-conversion
giflib
jsoncpp
libjpeg_turbo
numpy
openssl
pkgs.flatbuffers
pkgs.protobuf
pybind11
scipy
six
snappy
zlib
]
++ lib.optionals (!effectiveStdenv.isDarwin) [nsync];
# We don't want to be quite so picky regarding bazel version
postPatch = ''
rm -f .bazelversion
'';
bazelRunTarget = "//jaxlib/tools:build_wheel";
runTargetFlags = [
"--output_path=$out"
"--cpu=${arch}"
# This has no impact whatsoever...
"--jaxlib_git_hash='12345678'"
];
removeRulesCC = false;
GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";
# The version is automatically set to ".dev" if this variable is not set.
# https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
JAXLIB_RELEASE = "1";
preConfigure =
# Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
''
mkdir dummy-ldconfig
echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
chmod +x dummy-ldconfig/ldconfig
export PATH="$PWD/dummy-ldconfig:$PATH"
''
+
# Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
# for more info. We assume
# * `cpu = None`
# * `enable_nccl = True`
# * `target_cpu_features = "release"`
# * `rocm_amdgpu_targets = None`
# * `enable_rocm = False`
# * `build_gpu_plugin = False`
# * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
#
# Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
# instead of duplicating the logic here. Perhaps we can leverage the
# `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
''
cat <<CFG > ./.jax_configure.bazelrc
build --strategy=Genrule=standalone
build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
build --action_env=PYENV_ROOT
build --python_path="${python}/bin/python"
build --distinct_host_configuration=false
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
''
+ lib.optionalString cudaSupport ''
build --config=cuda
build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnn},${nccl}"
build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
''
+
# Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
# rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
# good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
# for upstream's version.
lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
''
build --config=avx_posix
''
+ lib.optionalString mklSupport ''
build --config=mkl_open_source_only
''
+ ''
CFG
'';
# Make sure Bazel knows about our configuration flags during fetching so that the
# relevant dependencies can be downloaded.
bazelFlags =
[
"-c opt"
# See https://bazel.build/external/advanced#overriding-repositories for
# information on --override_repository flag.
"--override_repository=xla=${xla}"
]
++ lib.optionals effectiveStdenv.cc.isClang [
# bazel depends on the compiler frontend automatically selecting these flags based on file
# extension but our clang doesn't.
# https://github.com/NixOS/nixpkgs/issues/150655
"--cxxopt=-x"
"--cxxopt=c++"
"--host_cxxopt=-x"
"--host_cxxopt=c++"
];
# We intentionally overfetch so we can share the fetch derivation across all the different configurations
fetchAttrs = {
TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
# we have to force @mkl_dnn_v1 since it's not needed on darwin
bazelTargets = [
bazelRunTarget
"@mkl_dnn_v1//:mkl_dnn"
];
bazelFlags =
bazelFlags
++ [
"--config=avx_posix"
"--config=mkl_open_source_only"
]
++ lib.optionals cudaSupport [
# ideally we'd add this unconditionally too, but it doesn't work on darwin
# we make this conditional on `cudaSupport` instead of the system, so that the hash for both
# the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
# have access to darwin machines
"--config=cuda"
];
sha256 =
(
if cudaSupport
then {x86_64-linux = "sha256-vUoAPkYKEnHkV4fw6BI0mCeuP2e8BMCJnVuZMm9LwSA=";}
else {
x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
}
)
.${effectiveStdenv.system}
or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
};
buildAttrs = {
outputs = ["out"];
TF_SYSTEM_LIBS = lib.concatStringsSep "," (
tf_system_libs
++ lib.optionals (!effectiveStdenv.isDarwin) [
"nsync" # fails to build on darwin
]
);
};
inherit meta;
};
platformTag =
if effectiveStdenv.hostPlatform.isLinux
then "manylinux2014_${arch}"
else if effectiveStdenv.system == "x86_64-darwin"
then "macosx_10_9_${arch}"
else if effectiveStdenv.system == "aarch64-darwin"
then "macosx_11_0_${arch}"
else throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
in
buildPythonPackage {
inherit meta pname version;
format = "wheel";
src = let
cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
# Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
# for more info.
postInstall = lib.optionalString cudaSupport ''
mkdir -p $out/bin
ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
patchelf --add-rpath "${
lib.makeLibraryPath [
cuda_libs_joined
cudnn
nccl
]
}" "$lib"
done
'';
nativeBuildInputs = lib.optionals cudaSupport [autoAddDriverRunpath];
dependencies = [
absl-py
curl
double-conversion
flatbuffers
giflib
jsoncpp
libjpeg_turbo
ml-dtypes
numpy
scipy
six
snappy
];
pythonImportsCheck = [
"jaxlib"
# `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
"jaxlib.cpu_feature_guard"
"jaxlib.xla_client"
];
# Without it there are complaints about libcudart.so.11.0 not being found
# because RPATH path entries added above are stripped.
dontPatchELF = cudaSupport;
}

View file

@ -0,0 +1,26 @@
{
lib,
python3,
stdenv,
}:
stdenv.mkDerivation {
pname = "realtime-stt-server";
version = "1.0.0";
dontUnpack = true;
propagatedBuildInputs = [
(python3.withPackages (pythonPackages: with pythonPackages; [realtime-stt]))
];
installPhase = ''
install -Dm755 ${./realtime-stt-server.py} $out/bin/realtime-stt-server
'';
meta = {
description = "";
homepage = "";
license = lib.licenses.mit;
maintainers = with lib.maintainers; [oddlama];
mainProgram = "realtime-stt-server";
};
}

258
pkgs/realtime-stt-server.py Normal file
View file

@ -0,0 +1,258 @@
#!/usr/bin/env python3
import argparse
import json
import logging
import numpy as np
import queue
import socket
import struct
import time
import sys
import threading
def send_message(sock, message):
message_str = json.dumps(message)
message_bytes = message_str.encode("utf-8")
message_length = len(message_bytes)
sock.sendall(struct.pack("!I", message_length))
sock.sendall(message_bytes)
def recv_message(sock):
length_bytes = sock.recv(4)
if not length_bytes or len(length_bytes) == 0:
return None
message_length = struct.unpack("!I", length_bytes)[0]
if message_length & 0x80000000 != 0:
# Raw audio data
message_length &= ~0x80000000
message_bytes = sock.recv(message_length)
return message_bytes
message_bytes = sock.recv(message_length)
message_str = message_bytes.decode("utf-8")
return json.loads(message_str)
class Client:
def __init__(self, tag, conn):
self.tag = tag
self.conn = conn
self.thread = threading.current_thread()
self.mode = None
self.is_true_client = False
self.waiting = False
self.queue = queue.Queue()
clients = {}
active_client = None
model_lock = threading.Lock()
def publish(obj, client=None):
msg = json.dumps(obj)
if client is None:
for c in clients.values():
if c.mode == "status":
c.queue.put(msg)
else:
client.queue.put(msg)
def refresh_status(client=None):
publish(dict(refresh_status=True), client=client)
def handle_client(conn, addr):
global recorder
global active_client
tag = f"{addr[0]}:{addr[1]}"
client = Client(tag, conn)
clients[addr] = client
try:
logger.info(f'{tag} Connected to client')
init = recv_message(conn)
logger.info(f'{tag} Client requested mode {init["mode"]}')
client.mode = init["mode"]
client.is_true_client = init["mode"] == "stream"
if init["mode"] == "status":
refresh_status(client) # refresh once after startup
while True:
message = json.loads(client.queue.get())
if "refresh_status" in message and message["refresh_status"] == True:
n_clients = len(list(filter(lambda x: x.is_true_client, clients.values())))
n_waiting = len(list(filter(lambda x: x.is_true_client and x.waiting, clients.values())))
status = {
"clients": n_clients,
"waiting": n_waiting,
}
send_message(conn, status)
client.queue.task_done()
else:
logger.info(f'{tag} Acquiring lock')
client.waiting = True
refresh_status()
send_message(conn, dict(status="waiting for lock"))
with model_lock:
active_client = client
client.waiting = False
refresh_status()
send_message(conn, dict(status="lock acquired"))
recorder.start()
def send_queue():
try:
while True:
message = client.queue.get()
if message is None:
return
send_message(conn, message)
client.queue.task_done()
except (OSError, ConnectionError):
logger.info(f"{tag} error in send queue: connection closed?")
sender_thread = threading.Thread(target=send_queue)
sender_thread.daemon = True
sender_thread.start()
try:
while True:
msg = recv_message(conn)
if msg is None:
break
if isinstance(msg, bytes):
recorder.feed_audio(msg)
continue
if "action" in msg and msg["action"] == "flush":
logger.info(f"{tag} flushing on client request")
# input some silence
for i in range(10):
recorder.feed_audio(bytes(1000))
recorder.stop()
logger.info(f"{tag} flushed")
continue
else:
logger.info(f"{tag} error in recv: invalid message: {msg}")
continue
except (OSError, ConnectionError):
logger.info(f"{tag} error in recv: connection closed?")
finally:
client.queue.put(None)
active_client = None
recorder.stop()
sender_thread.join()
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f'{tag} Error handling client: {e}')
finally:
refresh_status()
del clients[addr]
conn.close()
logger.info(f'{tag} Connection closed')
if __name__ == "__main__":
logging.basicConfig(format="%(levelname)s %(message)s")
logger = logging.getLogger("realtime-stt-server")
logger.setLevel(logging.DEBUG)
#logging.getLogger().setLevel(logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=43007)
args = parser.parse_args()
logger.info("Importing runtime")
from RealtimeSTT import AudioToTextRecorder
def text_detected(ts):
text, segments = ts
global active_client
if active_client is not None:
segments = [x._asdict() for x in segments]
active_client.queue.put(dict(kind="realtime", text=text, segments=segments))
recorder_ready = threading.Event()
recorder_config = {
'init_logging': False,
'use_microphone': False,
'spinner': False,
'model': 'large-v3',
'return_segments': True,
#'language': 'en',
'silero_sensitivity': 0.4,
'webrtc_sensitivity': 2,
'post_speech_silence_duration': 0.7,
'min_length_of_recording': 0.0,
'min_gap_between_recordings': 0,
'enable_realtime_transcription': True,
'realtime_processing_pause': 0,
'realtime_model_type': 'base',
'on_realtime_transcription_stabilized': text_detected,
}
def recorder_thread():
global recorder
global active_client
logger.info("Initializing RealtimeSTT...")
recorder = AudioToTextRecorder(**recorder_config)
logger.info("AudioToTextRecorder ready")
recorder_ready.set()
try:
while not recorder.is_shut_down:
text, segments = recorder.text()
if text == "":
continue
if active_client is not None:
segments = [x._asdict() for x in segments]
active_client.queue.put(dict(kind="result", text=text, segments=segments))
except (OSError, EOFError) as e:
logger.info(f"recorder thread failed: {e}")
return
recorder_thread = threading.Thread(target=recorder_thread)
recorder_thread.start()
recorder_ready.wait()
logger.info(f'Starting server on {args.host}:{args.port}')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((args.host, args.port))
s.listen()
logger.info(f'Server ready to accept connections')
try:
while True:
# Accept incoming connection
conn, addr = s.accept()
conn.setblocking(True)
# Create a new thread to handle the client
client_thread = threading.Thread(target=handle_client, args=(conn, addr))
client_thread.daemon = True # die with main thread
client_thread.start()
# Note: The main thread continues to accept new connections
except KeyboardInterrupt:
logger.info(f'Received shutdown request')
for c in clients.values():
try:
c.conn.close()
except (OSError, ConnectionError):
pass
try:
s.close()
except (OSError, ConnectionError):
pass
recorder.shutdown()
logger.info('Server terminated')

79
pkgs/realtime-stt.nix Normal file
View file

@ -0,0 +1,79 @@
{
lib,
buildPythonPackage,
fetchFromGitHub,
setuptools,
wheel,
faster-whisper,
pyaudio,
scipy,
torch,
torchaudio,
webrtcvad,
websockets,
}:
buildPythonPackage rec {
pname = "realtime-stt";
version = "0.1.16";
src = fetchFromGitHub {
owner = "oddlama";
repo = "RealtimeSTT";
rev = "master";
hash = "sha256-64RE/aT5PxuFFUTvjNefqTlAKWG1fftKV0wcY/hFlcg=";
};
nativeBuildInputs = [
setuptools
wheel
];
propagatedBuildInputs = [
faster-whisper
pyaudio
scipy
torch
torchaudio
webrtcvad
websockets
];
postPatch = ''
# Remove unneded modules
substituteInPlace RealtimeSTT/audio_recorder.py \
--replace-fail 'import pvporcupine' "" \
--replace-fail 'import halo' ""
'';
preBuild = ''
cat > setup.py << EOF
from setuptools import setup
setup(
name='realtime-stt',
packages=['RealtimeSTT'],
version='${version}',
install_requires=[
"PyAudio",
"faster-whisper",
#"pvporcupine",
"webrtcvad",
"#halo",
"torch",
"torchaudio",
"scipy",
"websockets",
],
)
EOF
'';
pythonImportsCheck = ["RealtimeSTT"];
meta = {
description = "A robust, efficient, low-latency speech-to-text library with advanced voice activity detection, wake word activation and instant transcription";
homepage = "https://github.com/KoljaB/RealtimeSTT";
license = lib.licenses.mit;
maintainers = with lib.maintainers; [oddlama];
};
}