import ast
import json
import logging

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from huggingface_hub import InferenceClient

_logger = logging.getLogger(__name__)

client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")

# File the rendered graph is written to. Every call to sampling() overwrites
# this same file.
_OUTPUT_PATH = 'synthnet.png'


def _parse_json_object(text):
    """Decode a model reply into a Python object.

    Strict JSON is tried first so that ``true``/``false``/``null`` and JSON
    string escapes decode correctly. Text that is not valid JSON falls back
    to ``ast.literal_eval`` so Python-literal style replies (for example
    single-quoted strings) are still accepted.

    Args:
        text: Raw message content returned by the model.

    Returns:
        The decoded object.

    Raises:
        Whatever ``ast.literal_eval`` raises (e.g. ``ValueError`` or
        ``SyntaxError``) when the text is neither JSON nor a Python literal.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return ast.literal_eval(text)


def _request_string_list(system_prompt, user_prompt, key, max_tokens):
    """Ask the model for a JSON object and return the value stored under *key*.

    The request constrains the reply with a ``response_format`` describing an
    object whose *key* property is an array of strings.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        key: Property name requested in the JSON object and read from it.
        max_tokens: Token limit passed to the completion call.

    Returns:
        The value found under *key* in the decoded reply.

    Raises:
        KeyError: If the decoded reply has no *key* entry.
    """
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        response_format={
            "type": "json",
            "value": {
                "properties": {
                    key: {"type": "array", "items": {"type": "string"}},
                }
            },
        },
        stream=False,
        max_tokens=max_tokens,
        temperature=0.7,
        top_p=0.1,
    )
    return _parse_json_object(response.choices[0].message.content)[key]


def _render_graph(triplets, path=_OUTPUT_PATH):
    """Draw ``(subject, predicate, object)`` triplets as a directed graph.

    Subjects and objects become labelled nodes; each predicate becomes the
    label of the edge from subject to object. The figure is saved to *path*.

    Args:
        triplets: Iterable of 3-item sequences.
        path: Destination file for the rendered image.

    Returns:
        *path*, unchanged.
    """
    graph = nx.DiGraph()
    for source, label, target in triplets:
        graph.add_node(source, label=source)
        graph.add_node(target, label=target)
        graph.add_edge(source, target, label=label)

    pos = nx.spring_layout(graph)

    # Draw on an explicit figure and always close it, so a failure while
    # drawing cannot leave artists behind on pyplot's implicit current figure.
    fig, ax = plt.subplots()
    try:
        nx.draw_networkx_nodes(
            graph, pos, node_size=500, node_color='lightblue', ax=ax
        )
        nx.draw_networkx_edges(graph, pos, arrowstyle='->', arrowsize=25, ax=ax)
        nx.draw_networkx_edge_labels(
            graph,
            pos,
            edge_labels=nx.get_edge_attributes(graph, 'label'),
            ax=ax,
        )
        nx.draw_networkx_labels(
            graph,
            pos,
            labels=nx.get_node_attributes(graph, 'label'),
            font_size=8,
            font_family="sans-serif",
            ax=ax,
        )
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def sampling(num_samples, num_associations):
    """Generate a word-association graph with the language model.

    The model is queried in three stages: first for ``num_samples`` words,
    then for ``num_associations`` associations per word, and finally for a
    predicate linking each word to each of its associations. The resulting
    triplets are drawn as a directed graph.

    Args:
        num_samples: Number of words to request.
        num_associations: Number of associations to request for each word.

    Returns:
        Path of the PNG file containing the rendered graph.
    """
    # Slider values may arrive as floats; coerce them so the prompts contain
    # plain integers.
    num_samples = int(num_samples)
    num_associations = int(num_associations)

    words = _request_string_list(
        "generate one json object, no explanation or additional text, use the following structure:\n"
        "words: []\n"
        f"{num_samples} samples in a list",
        f"synthesize {num_samples} random but widespread words for semantic modeling",
        "words",
        1024,
    )

    # Keyed by word, so a word the model repeats is only kept once.
    fields = {}
    for word in words:
        fields[word] = _request_string_list(
            'generate one json object, no explanation or additional text, use the following structure:\n'
            'associations: []',
            f"synthesize {num_associations} associations for the word {word}",
            "associations",
            2000,
        )

    triplets = []
    for cluster, associations in fields.items():
        for association in associations:
            triplet = _request_string_list(
                "generate one json object, no explanation or additional text, use the following structure:\n"
                "properties: [subject, predicate, object]\n"
                "use chain-of-thought for predictions",
                f"form triplet based on semantics: generate predicate between the word {cluster} (subject) and the word {association} (object); return list with [subject, predicate, object]",
                "properties",
                128,
            )
            # One malformed reply should not discard every triplet gathered
            # so far; skip it and keep going.
            if not isinstance(triplet, (list, tuple)) or len(triplet) != 3:
                _logger.warning(
                    "Skipping malformed triplet for %r -> %r: %r",
                    cluster,
                    association,
                    triplet,
                )
                continue
            triplets.append(tuple(triplet))

    return _render_graph(triplets)


demo = gr.Interface(
    inputs=[
        # step=1 keeps the sliders on whole numbers.
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Samples"),
        gr.Slider(
            minimum=1,
            maximum=256,
            step=1,
            label="Number of Associations to each Sample",
        ),
    ],
    fn=sampling,
    outputs=gr.Image(type="filepath"),
    title="SynthNet",
    description="Select a number of samples and associations to each sample to generate a graph.",
)

if __name__ == "__main__":
    demo.launch(share=True)