1
1
Fork 1
mirror of https://github.com/oddlama/nix-config.git synced 2025-10-10 14:50:40 +02:00

feat: add some cmp keybinds, remove realtime-stt-server (now in whisper-overlay)

This commit is contained in:
oddlama 2024-06-21 22:07:05 +02:00
parent e1e8997525
commit 9b428f2480
No known key found for this signature in database
GPG key ID: 14EFE510775FE39A
6 changed files with 17 additions and 368 deletions

View file

@ -38,10 +38,8 @@
++ [
(pythonFinal: _pythonPrev: {
jaxlib = pythonFinal.callPackage ./jaxlib.nix {};
realtime-stt = pythonFinal.callPackage ./realtime-stt.nix {};
})
];
realtime-stt-server = prev.callPackage ./realtime-stt-server.nix {};
formats =
prev.formats

View file

@ -1,26 +0,0 @@
# Packages the realtime-stt-server.py script from this directory as an
# executable named `realtime-stt-server`, together with a Python environment
# that provides the `realtime-stt` package.
{
  lib,
  python3,
  stdenv,
}: let
  # Interpreter environment that ships the realtime-stt Python package.
  pythonEnv = python3.withPackages (ps: [ps.realtime-stt]);
in
  stdenv.mkDerivation {
    pname = "realtime-stt-server";
    version = "1.0.0";

    # No `src` is set; the script is referenced directly in installPhase,
    # so there is nothing to unpack.
    dontUnpack = true;

    propagatedBuildInputs = [pythonEnv];

    installPhase = ''
      install -Dm755 ${./realtime-stt-server.py} $out/bin/realtime-stt-server
    '';

    meta = {
      description = "";
      homepage = "";
      license = lib.licenses.mit;
      maintainers = [lib.maintainers.oddlama];
      mainProgram = "realtime-stt-server";
    };
  }

View file

@ -1,258 +0,0 @@
#!/usr/bin/env python3
import argparse
import json
import logging
import numpy as np
import queue
import socket
import struct
import time
import sys
import threading
def send_message(sock, message):
    """Write *message* to *sock* as one length-prefixed JSON frame.

    The frame consists of a 4-byte big-endian unsigned length (``struct``
    format ``"!I"``) followed by the UTF-8 encoded JSON text of *message*.
    """
    payload = json.dumps(message).encode("utf-8")
    header = struct.pack("!I", len(payload))
    sock.sendall(header)
    sock.sendall(payload)
def _recv_exact(sock, count):
    """Read exactly *count* bytes from *sock*.

    ``sock.recv`` may hand back fewer bytes than requested, so keep reading
    until the full amount has arrived. Returns ``None`` if the peer closes
    the connection before *count* bytes were received.
    """
    chunks = []
    remaining = count
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def recv_message(sock):
    """Read one length-prefixed frame from *sock*.

    A frame starts with a 4-byte big-endian unsigned integer. If its top bit
    is set, the remaining 31 bits give the size of a raw audio payload, which
    is returned as ``bytes``. Otherwise the value is the size of a UTF-8 JSON
    document, which is decoded and returned as a Python object.

    Returns ``None`` when the connection is closed before a complete frame
    has been received.
    """
    header = _recv_exact(sock, 4)
    if header is None:
        return None
    (message_length,) = struct.unpack("!I", header)
    # The top bit of the length word flags a raw audio payload.
    is_raw_audio = bool(message_length & 0x80000000)
    message_length &= 0x7FFFFFFF
    payload = _recv_exact(sock, message_length)
    if payload is None:
        return None
    if is_raw_audio:
        return payload
    return json.loads(payload.decode("utf-8"))
class Client:
    """State kept for one accepted connection."""

    def __init__(self, tag, conn):
        # Connection identity: the socket and the label used in log lines.
        self.conn = conn
        self.tag = tag
        # Thread that constructed this object.
        self.thread = threading.current_thread()
        # Mode requested by the client; filled in once its first message arrives.
        self.mode = None
        # Flags read when the status counts are computed.
        self.waiting = False
        self.is_true_client = False
        # Messages queued for this client.
        self.queue = queue.Queue()
# Registry of connected clients, keyed by the peer address from accept().
clients = {}
# Client whose transcription results are currently being forwarded, or None.
active_client = None
# Held by a streaming client for as long as it uses the shared recorder.
model_lock = threading.Lock()
def publish(obj, client=None):
    """Queue *obj*, serialized as JSON text, for delivery to clients.

    With *client* given, only that client's queue receives the message.
    Otherwise the message is put on the queue of every registered client
    whose ``mode`` is ``"status"``.
    """
    msg = json.dumps(obj)
    if client is not None:
        client.queue.put(msg)
        return
    # Iterate over a snapshot: handler threads insert into and delete from
    # ``clients`` while this runs, and a dict that changes size during
    # iteration raises RuntimeError.
    for c in list(clients.values()):
        if c.mode == "status":
            c.queue.put(msg)
def refresh_status(client=None):
    """Publish a ``refresh_status`` marker message.

    The marker is handed to ``publish`` together with *client*, so it goes
    either to that one client or, when *client* is ``None``, to whichever
    clients ``publish`` addresses by default.
    """
    request = {"refresh_status": True}
    publish(request, client=client)
# Per-connection worker; the accept loop starts it in its own thread for
# every accepted socket. It reads the client's first message to learn the
# requested mode, then either serves status updates or feeds the client's
# audio into the shared recorder while holding model_lock.
# NOTE(review): indentation is not visible in this view, so the nesting
# described in the comments below is inferred from statement order; confirm
# against the original file.
def handle_client(conn, addr):
# `recorder` is assigned by the recorder thread in the __main__ section.
global recorder
global active_client
# "host:port" label used as the prefix of every log line for this client.
tag = f"{addr[0]}:{addr[1]}"
client = Client(tag, conn)
# Register the client so status counts and broadcasts can see it.
clients[addr] = client
try:
logger.info(f'{tag} Connected to client')
# The first message is indexed with "mode" below. If the peer disconnects
# first, recv_message returns None and the lookup raises; that ends up in
# the generic `except Exception` handler near the end of this function.
init = recv_message(conn)
logger.info(f'{tag} Client requested mode {init["mode"]}')
client.mode = init["mode"]
# Only "stream" clients are counted in the status numbers computed below.
client.is_true_client = init["mode"] == "stream"
if init["mode"] == "status":
refresh_status(client) # refresh once after startup
# Status mode: block on this client's own queue and answer every
# refresh marker with the current client counts.
while True:
# Queue entries are JSON text produced by publish().
message = json.loads(client.queue.get())
if "refresh_status" in message and message["refresh_status"] == True:
n_clients = len(list(filter(lambda x: x.is_true_client, clients.values())))
n_waiting = len(list(filter(lambda x: x.is_true_client and x.waiting, clients.values())))
status = {
"clients": n_clients,
"waiting": n_waiting,
}
send_message(conn, status)
client.queue.task_done()
# Presumably the branch for every mode other than "status" (pairs with the
# mode check above) — confirm against the original indentation.
else:
logger.info(f'{tag} Acquiring lock')
# Mark the client as waiting and broadcast a refresh before blocking on
# the lock.
client.waiting = True
refresh_status()
send_message(conn, dict(status="waiting for lock"))
with model_lock:
# From here on, transcription results are put on this client's queue.
active_client = client
client.waiting = False
refresh_status()
send_message(conn, dict(status="lock acquired"))
recorder.start()
# Sender: forwards queued messages to the socket until it sees the None
# sentinel that the receive loop's `finally` puts on the queue.
def send_queue():
try:
while True:
message = client.queue.get()
if message is None:
return
send_message(conn, message)
client.queue.task_done()
except (OSError, ConnectionError):
logger.info(f"{tag} error in send queue: connection closed?")
sender_thread = threading.Thread(target=send_queue)
sender_thread.daemon = True
sender_thread.start()
try:
# Receive loop: bytes frames are audio, dict frames are control messages.
while True:
msg = recv_message(conn)
# None means the connection was closed.
if msg is None:
break
if isinstance(msg, bytes):
recorder.feed_audio(msg)
continue
if "action" in msg and msg["action"] == "flush":
logger.info(f"{tag} flushing on client request")
# input some silence
# (ten chunks of 1000 zero bytes each)
for i in range(10):
recorder.feed_audio(bytes(1000))
recorder.stop()
logger.info(f"{tag} flushed")
continue
else:
logger.info(f"{tag} error in recv: invalid message: {msg}")
continue
except (OSError, ConnectionError):
logger.info(f"{tag} error in recv: connection closed?")
finally:
# Stop the sender via the sentinel, detach this client from the recorder
# output, stop the recorder, and wait for the sender to finish.
client.queue.put(None)
active_client = None
recorder.stop()
sender_thread.join()
except Exception as e:
# Catch-all for this connection: print the traceback and log the error.
import traceback
traceback.print_exc()
logger.error(f'{tag} Error handling client: {e}')
finally:
# NOTE(review): the refresh is published before the client is removed from
# the registry; whether status listeners then still count this client
# presumably depends on thread timing — confirm.
refresh_status()
del clients[addr]
conn.close()
logger.info(f'{tag} Connection closed')
# Script entry point: configures logging, parses --host/--port, creates the
# recorder in a background thread, then accepts TCP connections and hands
# each one to handle_client in its own thread.
# NOTE(review): indentation is not visible in this view; everything down to
# the final log line is presumably nested under this `if` — confirm.
if __name__ == "__main__":
logging.basicConfig(format="%(levelname)s %(message)s")
# `logger` is referenced as a global by handle_client.
logger = logging.getLogger("realtime-stt-server")
logger.setLevel(logging.DEBUG)
#logging.getLogger().setLevel(logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=43007)
args = parser.parse_args()
logger.info("Importing runtime")
# Imported after argument parsing; presumably deferred because the import
# is slow — confirm.
from RealtimeSTT import AudioToTextRecorder
# Registered below as the 'on_realtime_transcription_stabilized' callback.
# Puts the text and segments on the active client's queue as a "realtime"
# message; does nothing when no client is active.
def text_detected(ts):
text, segments = ts
global active_client
if active_client is not None:
# Segments expose _asdict(); convert them to plain dicts before queueing.
segments = [x._asdict() for x in segments]
active_client.queue.put(dict(kind="realtime", text=text, segments=segments))
# Set by the recorder thread once the recorder object has been constructed.
recorder_ready = threading.Event()
# Keyword arguments passed to AudioToTextRecorder.
recorder_config = {
'init_logging': False,
# Audio is fed in by clients through recorder.feed_audio().
'use_microphone': False,
'spinner': False,
'model': 'large-v3',
'return_segments': True,
#'language': 'en',
'silero_sensitivity': 0.4,
'webrtc_sensitivity': 2,
'post_speech_silence_duration': 0.7,
'min_length_of_recording': 0.0,
'min_gap_between_recordings': 0,
'enable_realtime_transcription': True,
'realtime_processing_pause': 0,
'realtime_model_type': 'base',
'on_realtime_transcription_stabilized': text_detected,
}
# Creates the global recorder, signals readiness, then repeatedly takes
# results from recorder.text() and puts non-empty ones on the active
# client's queue as "result" messages.
def recorder_thread():
global recorder
global active_client
logger.info("Initializing RealtimeSTT...")
recorder = AudioToTextRecorder(**recorder_config)
logger.info("AudioToTextRecorder ready")
recorder_ready.set()
try:
while not recorder.is_shut_down:
text, segments = recorder.text()
# Empty results are skipped.
if text == "":
continue
if active_client is not None:
segments = [x._asdict() for x in segments]
active_client.queue.put(dict(kind="result", text=text, segments=segments))
except (OSError, EOFError) as e:
logger.info(f"recorder thread failed: {e}")
return
# Note: this rebinds the name `recorder_thread` from the function above to
# the Thread object that runs it.
recorder_thread = threading.Thread(target=recorder_thread)
recorder_thread.start()
# Wait until the recorder exists before accepting any connection.
recorder_ready.wait()
logger.info(f'Starting server on {args.host}:{args.port}')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((args.host, args.port))
s.listen()
logger.info(f'Server ready to accept connections')
try:
while True:
# Accept incoming connection
conn, addr = s.accept()
conn.setblocking(True)
# Create a new thread to handle the client
client_thread = threading.Thread(target=handle_client, args=(conn, addr))
client_thread.daemon = True # die with main thread
client_thread.start()
# Note: The main thread continues to accept new connections
except KeyboardInterrupt:
logger.info(f'Received shutdown request')
# Shutdown: close every client socket and the listening socket, ignoring
# socket errors, then shut the recorder down.
for c in clients.values():
try:
c.conn.close()
except (OSError, ConnectionError):
pass
try:
s.close()
except (OSError, ConnectionError):
pass
recorder.shutdown()
logger.info('Server terminated')

View file

@ -1,79 +0,0 @@
# Builds the RealtimeSTT Python library from the oddlama/RealtimeSTT
# repository. The `pvporcupine` and `halo` imports are patched out of
# audio_recorder.py, and a setup.py that omits both requirements is
# generated before the build.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  setuptools,
  wheel,
  faster-whisper,
  pyaudio,
  scipy,
  torch,
  torchaudio,
  webrtcvad,
  websockets,
}:
buildPythonPackage rec {
  pname = "realtime-stt";
  version = "0.1.16";

  src = fetchFromGitHub {
    owner = "oddlama";
    repo = "RealtimeSTT";
    # NOTE(review): `master` is a moving ref while `hash` is fixed; presumably
    # the fetch fails once the branch advances — consider pinning a commit.
    rev = "master";
    hash = "sha256-64RE/aT5PxuFFUTvjNefqTlAKWG1fftKV0wcY/hFlcg=";
  };

  nativeBuildInputs = [
    setuptools
    wheel
  ];

  propagatedBuildInputs = [
    faster-whisper
    pyaudio
    scipy
    torch
    torchaudio
    webrtcvad
    websockets
  ];

  postPatch = ''
    # Remove unneeded modules
    substituteInPlace RealtimeSTT/audio_recorder.py \
      --replace-fail 'import pvporcupine' "" \
      --replace-fail 'import halo' ""
  '';

  # Write a setup.py whose requirement list matches the patched sources.
  # The heredoc body sits at the same indent as `cat` so that, after Nix
  # strips the common indentation, the closing EOF is at column 0.
  preBuild = ''
    cat > setup.py << EOF
    from setuptools import setup
    setup(
        name='realtime-stt',
        packages=['RealtimeSTT'],
        version='${version}',
        install_requires=[
            "PyAudio",
            "faster-whisper",
            #"pvporcupine",
            "webrtcvad",
            #"halo",
            "torch",
            "torchaudio",
            "scipy",
            "websockets",
        ],
    )
    EOF
  '';

  pythonImportsCheck = ["RealtimeSTT"];

  meta = {
    description = "A robust, efficient, low-latency speech-to-text library with advanced voice activity detection, wake word activation and instant transcription";
    homepage = "https://github.com/KoljaB/RealtimeSTT";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [oddlama];
  };
}

View file

@ -99,10 +99,18 @@
connected-active = "<span foreground='red'></span>";
};
return-type = "json";
exec = "${lib.getExe pkgs.whisper-overlay} waybar-status --address localhost:43007";
#exec = "${lib.getExe pkgs.whisper-overlay} waybar-status";
on-click-middle = lib.getExe (pkgs.writeShellApplication {
name = "restart-whisper-overlay";
runtimeInputs = [];
# FIXME: TODO and use libnotify
text = ''
'';
});
on-click-right = lib.getExe (pkgs.writeShellApplication {
name = "toggle-realtime-stt-server";
runtimeInputs = [];
# FIXME: TODO and use libnotify
text = ''
'';
});

View file

@ -62,7 +62,7 @@
end
if cmp.visible() then
cmp.select_next_item()
cmp.select_next_item({ behavior = cmp.SelectBehavior.Select })
elseif require("luasnip").expandable() then
require("luasnip").expand()
elseif require("luasnip").expand_or_locally_jumpable() then
@ -80,12 +80,18 @@
"<Up>" =
# lua
''cmp.mapping(cmp.mapping.select_prev_item({ behavior = cmp.SelectBehavior.Select }), {'i'})'';
"<PageDown>" =
# lua
''cmp.mapping(cmp.mapping.select_next_item({ behavior = cmp.SelectBehavior.Select, count = -10 }), {'i'})'';
"<PageUp>" =
# lua
''cmp.mapping(cmp.mapping.select_prev_item({ behavior = cmp.SelectBehavior.Select, count = 10 }), {'i'})'';
"<S-Tab>" =
# lua
''
cmp.mapping(function(fallback)
if cmp.visible() then
cmp.select_prev_item()
cmp.select_prev_item({ behavior = cmp.SelectBehavior.Select })
elseif luasnip.jumpable(-1) then
luasnip.jump(-1)
else