From 00d9b231e1cdef10c3fea7e6cc0b485bf4208031 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?David=20Lanzend=C3=B6rfer?= <leviathan@libresilicon.com>
Date: Thu, 4 Jul 2024 23:20:17 +0100
Subject: [PATCH] Inference test

---
 src/py/tty4.py | 41 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)
 create mode 100644 src/py/tty4.py

diff --git a/src/py/tty4.py b/src/py/tty4.py
new file mode 100644
index 0000000..4f2e9a4
--- /dev/null
+++ b/src/py/tty4.py
@@ -0,0 +1,41 @@
+import json
+
+from fac_tools import (
+    get_fac_wrapper,
+    load_weights_and_biases,
+    predict_next_token,
+    reset_states,
+    run_command,
+)
+
+# Token IDs produced offline with the GPT-2 tokenizer:
+#   from transformers import AutoTokenizer
+#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
+#   tokens = tokenizer("Taxation is theft").input_ids
+tokens = [27017, 341, 318, 12402]  # "Taxation is theft"
+
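+# Connect to the FAC accelerator over its telnet interface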
+server = get_fac_wrapper("telnet")
+
+print("Reading in JSON")
+with open("test_files/weights_and_biases_trained.json", "r") as f:
+    weights_and_biases = json.loads(f.read())
+    f.close()
+
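+# Upload the trained weights and biases to the device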
+load_weights_and_biases(server, weights_and_biases)
+
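+# Reset the on-chip state before running inference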
+reset_states(server)
+
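+# Feed the first prompt token and compare the prediction against the reference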
+print("Input Token: ", hex(tokens[0]))
+tok = predict_next_token(server, tokens[0])
+print("Output Token: ", hex(tok))
+print("Expected Token: ", hex(tokens[1]))
+
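+# Signal completion and tear down the connection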
+run_command(server, "DONE")
+run_command(server, "TERMINATE")
+server.close()
-- 
GitLab