In [ ]:
# Run this cell in Google Colab to retrieve the data
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/lh.pial.obj
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/icbm_fiber_mat.txt
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/fs_region_centers_68_sort.txt
!wget https://github.com/matteomancini/neurosnippets/raw/master/brainviz/interactive-network/freesurfer_regions_68_sort_full.txt
In [1]:
import numpy as np
import plotly.graph_objects as go
In [2]:
def obj_data_to_mesh3d(odata):
    """Parse the text of a Wavefront OBJ file into vertex and face arrays.

    Only ``v`` (vertex) and ``f`` (face) records are read; every other line
    (comments, normals, texture coordinates, blank lines, ...) is ignored.

    Parameters
    ----------
    odata : str
        Full contents of an OBJ file, as a single string.

    Returns
    -------
    vertices : numpy.ndarray
        One row per ``v`` record, holding the numbers that follow the ``v``
        keyword converted to float.
    faces : numpy.ndarray
        One row of three zero-based vertex indices per triangle. Faces with
        more than three corners are split into a fan of triangles that all
        share the face's first corner.

    Raises
    ------
    ValueError
        If a vertex component is not a number or a face corner does not start
        with an integer vertex index.
    """
    vertices = []
    faces = []

    for line in odata.splitlines():
        tokens = line.split()
        if not tokens:
            continue

        if tokens[0] == 'v':
            vertices.append(np.array(tokens[1:], dtype=float))
        elif tokens[0] == 'f':
            corners = []
            for token in tokens[1:]:
                # A corner is written "v", "v/vt", "v//vn" or "v/vt/vn"; only
                # the leading vertex index is needed, so the texture/normal
                # fields are not parsed at all.
                index = int(token.split('/')[0])
                if index < 0:
                    # Negative indices are relative to the vertices read so
                    # far: -1 is the most recently defined vertex.
                    index += len(vertices)
                else:
                    # Positive indices are 1-based in the file.
                    index -= 1
                corners.append(index)

            if len(corners) > 3:
                # Fan-triangulate the polygon around its first corner.
                faces.extend(
                    [corners[0], corners[k], corners[k + 1]]
                    for k in range(1, len(corners) - 1)
                )
            else:
                faces.append(corners)

    return np.array(vertices), np.array(faces)
In [3]:
# Numeric inputs, parsed from whitespace-delimited text: the connectivity
# matrix and the per-node coordinates.
cmat = np.loadtxt('icbm_fiber_mat.txt')
nodes = np.loadtxt('fs_region_centers_68_sort.txt')

# One label per line of the file, with newline characters stripped.
with open("freesurfer_regions_68_sort_full.txt", "r") as f:
    labels = [line.strip('\n') for line in f]
In [4]:
# Minimum connection weight for an edge to be drawn.
EDGE_THRESHOLD = 0.01

# np.triu zeroes everything below the diagonal, so each (row, col) pair with
# row <= col is considered once.
source, target = np.nonzero(np.triu(cmat) > EDGE_THRESHOLD)

nodes_x = nodes[:, 0]
nodes_y = nodes[:, 1]
nodes_z = nodes[:, 2]

# Flat coordinate lists for a single 'lines' trace. A lines trace joins every
# pair of consecutive points, so each edge is followed by None to break the
# line; without it a spurious segment would be drawn from one edge's target
# to the next edge's source.
edge_x = []
edge_y = []
edge_z = []
for s, t in zip(source, target):
    edge_x += [nodes_x[s], nodes_x[t], None]
    edge_y += [nodes_y[s], nodes_y[t], None]
    edge_z += [nodes_z[s], nodes_z[t], None]
In [5]:
# Read the whole surface file as text, then parse it into vertex/face arrays.
with open("lh.pial.obj", "r") as f:
    obj_data = f.read()
vertices, faces = obj_data_to_mesh3d(obj_data)

# Split into per-axis coordinates (first three vertex columns only) and
# per-corner index arrays, one 1-D array each.
vert_x, vert_y, vert_z = np.transpose(vertices[:, :3])
face_i, face_j, face_k = np.transpose(faces)
In [6]:
# Semi-transparent pink surface, with hover disabled.
surface_trace = go.Mesh3d(
    x=vert_x, y=vert_y, z=vert_z,
    i=face_i, j=face_j, k=face_k,
    color='pink', opacity=0.25, name='', showscale=False, hoverinfo='none',
)

# Network nodes as markers; hovering shows the node's label.
node_trace = go.Scatter3d(
    x=nodes_x, y=nodes_y, z=nodes_z, text=labels,
    mode='markers', hoverinfo='text', name='Nodes',
)

# Network edges as a single lines trace, with hover disabled.
edge_trace = go.Scatter3d(
    x=edge_x, y=edge_y, z=edge_z,
    mode='lines', hoverinfo='none', name='Edges',
)

# Trace order matters for the legend: surface first, then nodes, then edges.
fig = go.Figure(data=[surface_trace, node_trace, edge_trace])

# The same settings hide all three scene axes.
hidden_axis = dict(showticklabels=False, visible=False)
fig.update_layout(
    scene=dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis),
    width=800, height=600
)

fig.show()