65 lines
2.2 KiB
Python
65 lines
2.2 KiB
Python
import json
|
|
import yaml
|
|
import websocket
|
|
import os
|
|
|
|
# Directory that is scanned for post files.
BASE_DIR = '../janco_website/content/posts/'
# File that receives the JSON array of results.
OUTPUT_FILE = './output_shit.json'

# Iterator over the directory entries of BASE_DIR; consumed by the main loop.
files = os.scandir(BASE_DIR)

# Truncate the output file and write the opening bracket of the JSON array.
with open(OUTPUT_FILE, 'w') as of:
    print("[", file=of)
|
|
|
|
def _ends_with_record(path):
    """Return True when *path* already ends with a previously written record."""
    try:
        with open(path, 'rb') as existing:
            # Only the tail matters, so avoid re-reading the whole file.
            existing.seek(0, os.SEEK_END)
            size = existing.tell()
            existing.seek(max(size - 64, 0))
            tail = existing.read().rstrip()
    except FileNotFoundError:
        return False
    # A serialised JSON value never ends with "[", so a trailing "[" can only
    # be the array's opening bracket (i.e. no record has been written yet).
    return bool(tail) and not tail.endswith(b"[")


def write_result(data, path=None):
    """Append one JSON-encoded record to the output array file.

    The script writes ``[`` before the first record and ``]`` after the last
    one, so consecutive records must be separated by commas for the finished
    file to be valid JSON. A ``,`` separator is written whenever the file
    already ends with a record.

    Args:
        data: JSON-serialisable object to append.
        path: File to append to. Defaults to the module-level ``OUTPUT_FILE``.

    Raises:
        TypeError: If ``data`` cannot be serialised by ``json.dumps``.
    """
    target = OUTPUT_FILE if path is None else path
    # Serialise first so a failure cannot leave a dangling separator behind.
    serialized = json.dumps(data)
    needs_separator = _ends_with_record(target)
    with open(target, 'a') as of:
        if needs_separator:
            of.write(",\n")
        of.write(serialized)
|
|
|
|
def send_msg(sess, data):
    """Log *data* to stdout, then pass its JSON encoding to ``sess.send``.

    Args:
        sess: Object exposing a ``send(str)`` method.
        data: JSON-serialisable request payload.
    """
    print("Send_msg", data)
    encoded = json.dumps(data)
    sess.send(encoded)
|
|
|
|
def talk(sess, inputs):
    """Send a ``generate`` request for *inputs* over *sess* via ``send_msg``.

    Args:
        sess: Session object forwarded to ``send_msg``.
        inputs: Prompt text placed in the request's ``inputs`` field.
    """
    request = {
        "type": "generate",
        "inputs": inputs,
        "max_new_tokens": 1,
        "stop_sequence": "</s>",
        "extra_stop_sequences": ["\n\nHuman"],
        "do_sample": 1,
        "temperature": 0.9,
        "top_k": 40,
    }
    send_msg(sess, request)
|
|
|
|
def collect_response(sess):
    """Read frames from *sess* until one is flagged ``stop`` and join their text.

    Every frame is printed raw and then parsed as JSON.

    Args:
        sess: Object exposing a ``recv()`` method that returns one JSON
            document per call.

    Returns:
        The concatenated ``outputs`` of all frames received before the stop
        frame. The stop frame's own ``outputs`` is not included.

    Raises:
        RuntimeError: If a frame carries ``"ok": false``; the message includes
            the frame's ``traceback`` field when present.
        KeyError: If a non-stop frame has no ``outputs`` field.
    """
    pieces = []
    while True:
        raw = sess.recv()
        print(raw)
        res = json.loads(raw)
        # Surface a reported failure directly instead of tripping over the
        # missing 'outputs' key with an uninformative KeyError.
        if res.get('ok') is False:
            raise RuntimeError(f"generation failed: {res.get('traceback', raw)}")
        if res.get('stop'):
            return "".join(pieces)
        pieces.append(res['outputs'])
|
|
|
|
# Client usage follows the websocket-client examples:
# https://websocket-client.readthedocs.io/en/latest/examples.html
sess = websocket.WebSocket()
sess.connect("ws://chat.petals.ml/api/v2/generate")

# Open an inference session on the model and print the server's reply.
open_request = {
    "type": "open_inference_session",
    "model": "bigscience/bloomz-petals",
    "max_length": 1024,
}
send_msg(sess, open_request)
print(sess.recv())

# Send an initial example dialogue through the shared request helper and read
# its reply; the returned text is not used.
talk(
    sess,
    "A human talks to a powerful AI that follows the human's instructions."
    "\n\nHuman: Hi!\n\nAI: Hi! How can I help you?</s>Human: Hello\n\nAI:",
)
collect_response(sess)
|
|
|
|
|
|
# The instruction is identical for every post, so build it once.
prefix = "Donnez un titre descriptif résumant ce texte :"

# os.scandir yields entries in arbitrary order; sorting by name makes the
# order of the output records reproducible between runs.
for fname in sorted(files, key=lambda entry: entry.name):
    # Sub-directories and other non-file entries cannot be read as posts.
    if not fname.is_file():
        continue

    print("=========")
    print(f"opening {fname.path}")
    # NOTE(review): posts are assumed to be UTF-8 encoded - confirm.
    with open(fname.path, encoding="utf-8") as inp_f:
        file_content = inp_f.read()

    # Expected layout: '---', YAML headers, '---', body. maxsplit=2 keeps any
    # later '---' (e.g. a Markdown horizontal rule) inside the body instead of
    # cutting the body short at that point.
    compos = file_content.split('---', 2)
    if len(compos) < 3:
        print(f"skipping {fname.path}: no front matter found")
        continue
    headers = yaml.safe_load(compos[1])
    content = compos[2].strip()

    print(headers)
    print(content)

    # Ask the model for a title through the shared request helper.
    talk(sess, "Human: " + prefix + content + "\nAI:")
    generated_title = collect_response(sess)
    print(generated_title)

    # NOTE(review): yaml.safe_load can return date/datetime values for
    # timestamp-like headers, which json.dumps cannot serialise - confirm
    # against the actual posts.
    write_result({"on_post_file_path": fname.path, "generated_title": generated_title, "on_post": headers})

# Close the JSON array opened when the output file was created.
with open(OUTPUT_FILE, 'a') as of:
    of.write("\n]\n")
|