@@ -280,11 +280,17 @@ def random_packed_structure(
280
280
diameter = get_diameter (composition )
281
281
print (f"Using random pack diameter of { diameter } " )
282
282
283
+ # Ensure cell has batch dimension [1, 3, 3] if it doesn't already
284
+ if cell .ndim == 2 :
285
+ cell = cell .unsqueeze (0 ) # Add batch dimension
286
+
283
287
# Perform overlap minimization if diameter is specified
284
288
if diameter is not None :
285
289
print ("Reduce atom overlap using the soft_sphere potential" )
286
290
# Convert fractional to cartesian coordinates
287
- positions_cart = torch .matmul (positions , cell )
291
+ positions_cart = torch .matmul (
292
+ positions , cell .squeeze (0 )
293
+ ) # Use first (and only) batch
288
294
289
295
# Initialize soft sphere potential calculator
290
296
model = SoftSphereModel (
@@ -296,23 +302,30 @@ def random_packed_structure(
296
302
)
297
303
298
304
# Dummy atomic numbers
299
- atomic_numbers = torch .ones_like (positions_cart , device = device , dtype = torch .int )
305
+ atomic_numbers = torch .ones (N_atoms , device = device , dtype = torch .int )
306
+
307
+ # Create batch tensor for single system
308
+ batch = torch .zeros (N_atoms , device = device , dtype = torch .long )
300
309
301
310
# Set up FIRE optimizer with unit masses
302
311
state = ts .SimState (
303
312
positions = positions_cart ,
304
313
masses = torch .ones (N_atoms , device = device , dtype = dtype ),
305
314
atomic_numbers = atomic_numbers ,
306
- cell = cell ,
315
+ cell = cell , # Keep batch dimension
307
316
pbc = True ,
317
+ batch = batch ,
308
318
)
309
319
fire_init , fire_update = fire (model = model )
310
320
state = fire_init (state )
311
321
print (f"Initial energy: { state .energy .item ():.4f} " )
312
322
# Run FIRE optimization until convergence or max iterations
313
323
for _step in range (max_iter ):
314
324
# Check if minimum distance criterion is met (95% of target diameter)
315
- if min_distance (state .positions , cell , distance_tolerance ) > diameter * 0.95 :
325
+ if (
326
+ min_distance (state .positions , cell .squeeze (0 ), distance_tolerance )
327
+ > diameter * 0.95
328
+ ):
316
329
break
317
330
318
331
if log is not None :
@@ -321,6 +334,20 @@ def random_packed_structure(
321
334
state = fire_update (state )
322
335
323
336
print (f"Final energy: { state .energy .item ():.4f} " )
337
+ else :
338
+ # If no optimization, still create a proper state with batch dimensions
339
+ positions_cart = torch .matmul (positions , cell .squeeze (0 ))
340
+ atomic_numbers = torch .ones (N_atoms , device = device , dtype = torch .int )
341
+ batch = torch .zeros (N_atoms , device = device , dtype = torch .long )
342
+
343
+ state = ts .SimState (
344
+ positions = positions_cart ,
345
+ masses = torch .ones (N_atoms , device = device , dtype = dtype ),
346
+ atomic_numbers = atomic_numbers ,
347
+ cell = cell , # Keep batch dimension
348
+ pbc = True ,
349
+ batch = batch ,
350
+ )
324
351
325
352
if log is not None :
326
353
return state , log
@@ -408,11 +435,17 @@ def random_packed_structure_multi(
408
435
diameter_matrix = get_diameter_matrix (composition , device = device , dtype = dtype )
409
436
print (f"Using random pack diameter matrix:\n { diameter_matrix .cpu ().numpy ()} " )
410
437
438
+ # Ensure cell has batch dimension [1, 3, 3] if it doesn't already
439
+ if cell .ndim == 2 :
440
+ cell = cell .unsqueeze (0 ) # Add batch dimension
441
+
411
442
# Perform overlap minimization if diameter matrix is specified
412
443
if diameter_matrix is not None :
413
444
print ("Reduce atom overlap using the soft_sphere potential" )
414
445
# Convert fractional to cartesian coordinates
415
- positions_cart = torch .matmul (positions , cell )
446
+ positions_cart = torch .matmul (
447
+ positions , cell .squeeze (0 )
448
+ ) # Use first (and only) batch
416
449
417
450
# Initialize multi-species soft sphere potential calculator
418
451
model = SoftSphereMultiModel (
@@ -425,14 +458,18 @@ def random_packed_structure_multi(
425
458
)
426
459
427
460
# Dummy atomic numbers
428
- atomic_numbers = torch .ones_like (positions_cart , device = device , dtype = torch .int )
461
+ atomic_numbers = torch .ones (N_atoms , device = device , dtype = torch .int )
462
+
463
+ # Create batch tensor for single system
464
+ batch = torch .zeros (N_atoms , device = device , dtype = torch .long )
429
465
430
466
state_dict = ts .SimState (
431
467
positions = positions_cart ,
432
468
masses = torch .ones (N_atoms , device = device , dtype = dtype ),
433
469
atomic_numbers = atomic_numbers ,
434
- cell = cell ,
470
+ cell = cell , # Keep batch dimension
435
471
pbc = True ,
472
+ batch = batch ,
436
473
)
437
474
# Set up FIRE optimizer with unit masses for all atoms
438
475
fire_init , fire_update = fire (model = model )
@@ -441,11 +478,25 @@ def random_packed_structure_multi(
441
478
# Run FIRE optimization until convergence or max iterations
442
479
for _step in range (max_iter ):
443
480
# Check if minimum distance criterion is met (95% of smallest target diameter)
444
- min_dist = min_distance (state .positions , cell , distance_tolerance )
481
+ min_dist = min_distance (state .positions , cell . squeeze ( 0 ) , distance_tolerance )
445
482
if min_dist > diameter_matrix .min () * 0.95 :
446
483
break
447
484
state = fire_update (state )
448
485
print (f"Final energy: { state .energy .item ():.4f} " )
486
+ else :
487
+ # If no optimization, still create a proper state with batch dimensions
488
+ positions_cart = torch .matmul (positions , cell .squeeze (0 ))
489
+ atomic_numbers = torch .ones (N_atoms , device = device , dtype = torch .int )
490
+ batch = torch .zeros (N_atoms , device = device , dtype = torch .long )
491
+
492
+ state = ts .SimState (
493
+ positions = positions_cart ,
494
+ masses = torch .ones (N_atoms , device = device , dtype = dtype ),
495
+ atomic_numbers = atomic_numbers ,
496
+ cell = cell , # Keep batch dimension
497
+ pbc = True ,
498
+ batch = batch ,
499
+ )
449
500
450
501
return state
451
502
@@ -472,8 +523,8 @@ def valid_subcell(
472
523
Args:
473
524
positions: Atomic positions tensor of shape [n_atoms, 3], where each row contains
474
525
the (x,y,z) coordinates of an atom.
475
- cell: Unit cell tensor of shape [3, 3] containing the three lattice vectors that
476
- define the periodic boundary conditions.
526
+ cell: Unit cell tensor of shape [3, 3] or [1, 3, 3] containing the three lattice
527
+ vectors that define the periodic boundary conditions.
477
528
initial_energy: Total energy of the structure before relaxation, in eV.
478
529
final_energy: Total energy of the structure after relaxation, in eV.
479
530
e_tol: Energy tolerance for comparing initial and final energies, in eV.
@@ -510,8 +561,9 @@ def valid_subcell(
510
561
return False
511
562
512
563
# Check minimum interatomic distances to detect atomic fusion
513
- # Uses periodic boundary conditions via min_distance function
514
- min_dist = min_distance (positions , cell , distance_tolerance )
564
+ # Handle both batched and unbatched cell tensors
565
+ cell_for_min_dist = cell .squeeze (0 ) if cell .ndim == 3 else cell
566
+ min_dist = min_distance (positions , cell_for_min_dist , distance_tolerance )
515
567
if min_dist < fusion_distance :
516
568
print ("Bad structure! Fusion found." )
517
569
return False
@@ -645,15 +697,18 @@ def subcells_to_structures(
645
697
candidates: List of (ids, lower_bound, upper_bound)
646
698
tuples from get_subcells_to_crystallize
647
699
fractional_positions: Fractional coordinates of atoms
648
- cell: Unit cell tensor
700
+ cell: Unit cell tensor of shape [3, 3] or [1, 3, 3]
649
701
species: List of atomic species symbols
650
702
651
703
Returns:
652
704
list[tuple[torch.Tensor, torch.Tensor, list[str]]]: Each tuple contains:
653
705
- fractional_positions: Fractional coordinates of atoms
654
- - cell: Unit cell tensor
706
+ - cell: Unit cell tensor with proper batch dimensions
655
707
- species: atomic species symbols
656
708
"""
709
+ # Handle both batched and unbatched cell tensors
710
+ cell_2d = cell .squeeze (0 ) if cell .ndim == 3 else cell
711
+
657
712
list_subcells = []
658
713
for ids , lower_bound , upper_bound in candidates :
659
714
# Get positions of atoms in this subcell
@@ -666,7 +721,11 @@ def subcells_to_structures(
666
721
new_frac_pos = new_frac_pos / (upper_bound - lower_bound )
667
722
668
723
# Calculate new cell parameters
669
- new_cell = cell * (upper_bound - lower_bound ).unsqueeze (0 )
724
+ new_cell = cell_2d * (upper_bound - lower_bound ).unsqueeze (0 )
725
+
726
+ # Add batch dimension to maintain consistency
727
+ if cell .ndim == 3 : # Original cell had batch dimension
728
+ new_cell = new_cell .unsqueeze (0 )
670
729
671
730
# Get species for these atoms and convert tensor indices to list/numpy array
672
731
# before indexing species list
@@ -721,12 +780,18 @@ def get_unit_cell_relaxed_structure(
721
780
tuple containing:
722
781
- UnitCellFIREState: Final state containing relaxed positions, cell and more
723
782
- dict: Logger with energy and stress trajectories
724
- - float: Final energy in eV
725
- - float: Final pressure in eV/ų
726
783
"""
727
784
# Get device and dtype from model
728
785
device , dtype = model .device , model .dtype
729
786
787
+ # Ensure state has proper batch dimensions
788
+ if state .cell .ndim == 2 :
789
+ state .cell = state .cell .unsqueeze (0 ) # Add batch dimension
790
+
791
+ # Ensure batch tensor exists and has correct shape
792
+ if state .batch is None :
793
+ state .batch = torch .zeros (len (state .positions ), device = device , dtype = torch .long )
794
+
730
795
logger = {
731
796
"energy" : torch .zeros ((max_iter , state .n_batches ), device = device , dtype = dtype ),
732
797
"stress" : torch .zeros (
@@ -769,7 +834,7 @@ def step_fn(
769
834
f"Final energy: { [f'{ e :.4f} ' for e in final_energy ]} eV, "
770
835
f"Final pressure: { [f'{ p :.4f} ' for p in final_pressure ]} eV/A^3"
771
836
)
772
- return state , logger , final_energy , final_pressure
837
+ return state , logger
773
838
774
839
775
840
def get_relaxed_structure (
@@ -791,12 +856,18 @@ def get_relaxed_structure(
791
856
tuple containing:
792
857
- FIREState: Final state containing relaxed positions and other quantities
793
858
- dict: Logger with energy trajectory
794
- - float: Final energy in eV
795
- - float: Final pressure in eV/ų
796
859
"""
797
860
# Get device and dtype from model
798
861
device , dtype = model .device , model .dtype
799
862
863
+ # Ensure state has proper batch dimensions
864
+ if state .cell .ndim == 2 :
865
+ state .cell = state .cell .unsqueeze (0 ) # Add batch dimension
866
+
867
+ # Ensure batch tensor exists and has correct shape
868
+ if state .batch is None :
869
+ state .batch = torch .zeros (len (state .positions ), device = device , dtype = torch .long )
870
+
800
871
logger = {"energy" : torch .zeros ((max_iter , 1 ), device = device , dtype = dtype )}
801
872
802
873
results = model (state )
@@ -816,9 +887,7 @@ def step_fn(idx: int, state: FireState, logger: dict) -> tuple[FireState, dict]:
816
887
817
888
# Get final results
818
889
model .compute_stress = True
819
- final_results = model (
820
- positions = state .positions , cell = state .cell , atomic_numbers = state .atomic_numbers
821
- )
890
+ final_results = model (state )
822
891
823
892
final_energy = final_results ["energy" ].item ()
824
893
final_stress = final_results ["stress" ]
@@ -827,4 +896,4 @@ def step_fn(idx: int, state: FireState, logger: dict) -> tuple[FireState, dict]:
827
896
f"Final energy: { final_energy :.4f} eV, "
828
897
f"Final pressure: { final_pressure :.4f} eV/A^3"
829
898
)
830
- return state , logger , final_energy , final_pressure
899
+ return state , logger
0 commit comments